# Source code for transformers.modeling_tf_utils

# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TF general model utils."""
import functools
import logging
import os

import h5py
import numpy as np
import tensorflow as tf
from tensorflow.python.keras.saving import hdf5_format

from .configuration_utils import PretrainedConfig
from .file_utils import DUMMY_INPUTS, TF2_WEIGHTS_NAME, WEIGHTS_NAME, cached_path, hf_bucket_url, is_remote_url
from .generation_tf_utils import TFGenerationMixin
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class TFModelUtilsMixin:
    """
    Mixin adding small helper utilities to `tf.keras.Model` subclasses.
    """

    def num_parameters(self, only_trainable: bool = False) -> int:
        """
        Count the parameters of the model.

        When `only_trainable` is set, the element counts of every entry of
        `self.trainable_variables` are summed up; otherwise the value of
        `self.count_params()` is returned as-is.
        """
        if not only_trainable:
            return self.count_params()

        sizes = [np.prod(variable.shape.as_list()) for variable in self.trainable_variables]
        return int(sum(sizes))


def keras_serializable(cls):
    """
    Class decorator that makes a Keras Layer subclass (de)serializable by Keras.

    Three things happen to the decorated class:

    1. `__init__` is wrapped so that it also accepts a `transformers_config` dict (what Keras passes back at
       deserialization time) and turns it into a config object before calling the real initializer.
    2. If the class still uses a `get_config` carrying the `_is_default` marker, it is replaced by one that adds a
       `transformers_config` dict (plus the keyword arguments given at construction) to the Keras config.
    3. The class is registered as a Keras custom object when `tf.keras.utils.register_keras_serializable` exists, so
       that it does not need to be listed in `custom_objects` when calling `tf.keras.models.load_model`.

    :param cls: a tf.keras.layers.Layers subclass that accepts a `config` argument to its initializer (typically a
                `TF*MainLayer` class in this project). It must define a non-None `config_class` attribute.
    :return: the same class object, with modifications for Keras deserialization.
    :raises AttributeError: if `cls.config_class` is missing or None.
    :raises TypeError: if `cls` has no `get_config` attribute.
    """
    original_init = cls.__init__

    config_class = getattr(cls, "config_class", None)
    if config_class is None:
        raise AttributeError("Must set `config_class` to use @keras_serializable")

    @functools.wraps(original_init)
    def wrapped_init(self, *args, **kwargs):
        serialized_config = kwargs.pop("transformers_config", None)

        # A config object may arrive either as the first positional argument or as the `config` keyword.
        if args and isinstance(args[0], PretrainedConfig):
            config = args[0]
        else:
            config = kwargs.get("config", None)

        got_config = config is not None
        got_serialized = serialized_config is not None

        if got_config and got_serialized:
            raise ValueError("Must pass either `config` or `transformers_config`, not both")
        if not got_config and not got_serialized:
            raise ValueError("Must pass either `config` (PretrainedConfig) or `transformers_config` (dict)")

        if got_serialized:
            # Keras deserialization path: rebuild the config object from its dict form.
            config = config_class.from_dict(serialized_config)
            original_init(self, config, *args, **kwargs)
        else:
            # Regular construction: the config is already part of args/kwargs, forward them untouched.
            original_init(self, *args, **kwargs)

        self._transformers_config = config
        self._kwargs = kwargs

    cls.__init__ = wrapped_init

    if not hasattr(cls, "get_config"):
        raise TypeError("Only use @keras_serializable on tf.keras.layers.Layer subclasses")

    uses_default_get_config = hasattr(cls.get_config, "_is_default")
    if uses_default_get_config:

        def get_config(self):
            keras_config = super(cls, self).get_config()
            keras_config["transformers_config"] = self._transformers_config.to_dict()
            keras_config.update(self._kwargs)
            return keras_config

        cls.get_config = get_config

    cls._keras_serializable = True
    if hasattr(tf.keras.utils, "register_keras_serializable"):
        register = tf.keras.utils.register_keras_serializable()
        cls = register(cls)
    return cls


class TFQuestionAnsweringLoss:
    """Loss mixin for span-prediction (question answering) heads."""

    def compute_loss(self, labels, logits):
        """
        Average the unreduced sparse cross-entropy of the start and end predictions.

        `labels` is indexed with the keys "start_position" and "end_position";
        `logits` is indexed with 0 (start) and 1 (end).
        """
        cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(
            reduction=tf.keras.losses.Reduction.NONE, from_logits=True
        )
        loss_on_start = cross_entropy(labels["start_position"], logits[0])
        loss_on_end = cross_entropy(labels["end_position"], logits[1])
        summed = loss_on_start + loss_on_end

        return summed / 2.0


class TFTokenClassificationLoss:
    """Loss mixin for token classification heads."""

    def compute_loss(self, labels, logits):
        """
        Unreduced sparse cross-entropy over the positions whose label is not -1.

        Labels are flattened to one dimension and logits to two (the last size is taken from index 2 of the
        logits shape); positions labelled -1 are masked out of both before the loss is computed.
        """
        cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(
            reduction=tf.keras.losses.Reduction.NONE, from_logits=True
        )

        flat_labels = tf.reshape(labels, (-1,))
        keep = flat_labels != -1

        num_classes = shape_list(logits)[2]
        flat_logits = tf.reshape(logits, (-1, num_classes))

        kept_logits = tf.boolean_mask(flat_logits, keep)
        kept_labels = tf.boolean_mask(flat_labels, keep)
        return cross_entropy(kept_labels, kept_logits)


class TFSequenceClassificationLoss:
    """Loss mixin for sequence-level classification / regression heads."""

    def compute_loss(self, labels, logits):
        """
        Unreduced loss between `labels` and `logits`.

        Mean squared error is used when the second dimension of `logits` has size 1, sparse categorical
        cross-entropy (on logits) otherwise.
        """
        no_reduction = tf.keras.losses.Reduction.NONE
        single_output = shape_list(logits)[1] == 1

        if single_output:
            criterion = tf.keras.losses.MeanSquaredError(reduction=no_reduction)
        else:
            criterion = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=no_reduction)

        return criterion(labels, logits)


# Multiple-choice heads reuse the sequence-classification loss as-is (plain alias, same class object).
TFMultipleChoiceLoss = TFSequenceClassificationLoss


class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
    r""" Base class for all TF models.

        :class:`~transformers.TFPreTrainedModel` takes care of storing the configuration of the models and handles
        methods for loading/downloading/saving models as well as a few methods common to all models to
        (i) resize the input embeddings and (ii) prune heads in the self-attention heads.

        Class attributes (overridden by derived classes):
            - ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration
              class for this model architecture. It is used by :func:`from_pretrained` to load the configuration.
            - ``base_model_prefix``: a string indicating the attribute associated to the base model in derived
              classes of the same architecture adding modules on top of the base model.
    """
    config_class = None
    base_model_prefix = ""

    @property
    def dummy_inputs(self):
        """ Dummy inputs to build the network.

        Returns:
            dict with a single ``"input_ids"`` entry holding a ``tf.Tensor`` built from ``DUMMY_INPUTS``.
        """
        return {"input_ids": tf.constant(DUMMY_INPUTS)}

    def __init__(self, config, *inputs, **kwargs):
        """
        Args:
            config: instance of :class:`~transformers.PretrainedConfig`, stored as ``self.config``.
            inputs, kwargs: forwarded to the parent initializer.

        Raises:
            ValueError: if ``config`` is not a :class:`~transformers.PretrainedConfig` instance.
        """
        super().__init__(*inputs, **kwargs)
        if not isinstance(config, PretrainedConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
                "To create a model from a pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                )
            )
        # Save config in model
        self.config = config

    def get_input_embeddings(self):
        """
        Returns the model's input embeddings.

        The call is delegated to the base model (the attribute named by ``base_model_prefix``).

        Returns:
            :obj:`tf.keras.layers.Layer`:
                A layer mapping vocabulary to hidden states.

        Raises:
            NotImplementedError: if there is no separate base model to delegate to; base models must override this.
        """
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            return base_model.get_input_embeddings()
        else:
            raise NotImplementedError

    def set_input_embeddings(self, value):
        """
        Set model's input embeddings

        The call is delegated to the base model (the attribute named by ``base_model_prefix``).

        Args:
            value (:obj:`tf.keras.layers.Layer`):
                A layer mapping vocabulary to hidden states.

        Raises:
            NotImplementedError: if there is no separate base model to delegate to; base models must override this.
        """
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
            base_model.set_input_embeddings(value)
        else:
            raise NotImplementedError

    def get_output_embeddings(self):
        """
        Returns the model's output embeddings.

        Returns:
            :obj:`tf.keras.layers.Layer`:
                A layer mapping hidden states to vocabulary, or ``None`` (the default here) when the model has no
                output embeddings.
        """
        return None  # Overwrite for models with output embeddings

    def resize_token_embeddings(self, new_num_tokens=None):
        """ Resize input token embeddings matrix of the model if new_num_tokens differs from the current number
        of tokens.

        Arguments:

            new_num_tokens: (`optional`) int:
                New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors
                at the end. Reducing the size will remove vectors from the end.
                If not provided or None: does nothing and just returns the input embeddings of the base model.

        Return:
            The value of ``get_input_embeddings()`` of the base model after the (optional) resize.
        """
        return self._resize_token_embeddings(new_num_tokens)

    def _resize_token_embeddings(self, new_num_tokens):
        """Resize the base model's input embeddings and record the new vocabulary size."""
        # get_input_embeddings and set_input_embeddings need to be implemented in base layer.
        base_model = getattr(self, self.base_model_prefix, self)

        # Nothing to resize: return early so that neither the embeddings nor `vocab_size` get overwritten
        # (the config would otherwise end up with `vocab_size = None`).
        if new_num_tokens is None:
            return base_model.get_input_embeddings()

        old_embeddings = base_model.get_input_embeddings()
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
        base_model.set_input_embeddings(new_embeddings)

        # Update base model and current model config
        self.config.vocab_size = new_num_tokens
        base_model.vocab_size = new_num_tokens

        return base_model.get_input_embeddings()

    def _get_word_embeddings(self, embeddings):
        """Return the word embedding weights held by ``embeddings``.

        Raises:
            ValueError: if ``embeddings`` has neither a ``word_embeddings`` nor a ``weight`` attribute.
        """
        if hasattr(embeddings, "word_embeddings"):
            # TFBertEmbeddings, TFAlbertEmbeddings, TFElectraEmbeddings
            return embeddings.word_embeddings
        elif hasattr(embeddings, "weight"):
            # TFSharedEmbeddings
            return embeddings.weight
        else:
            raise ValueError("word embedding is not defined.")

    def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None):
        """ Build a resized Embedding Variable from a provided token Embedding Module.
            Increasing the size will add newly initialized vectors at the end
            Reducing the size will remove vectors from the end.

        Args:
            old_embeddings: embedding module whose word embeddings are read with ``_get_word_embeddings``.
            new_num_tokens: (`optional`) int
                New number of tokens in the embedding matrix.
                Increasing the size will add newly initialized vectors at the end
                Reducing the size will remove vectors from the end
                If not provided or None: return the current word embeddings unchanged.

        Return: ``tf.Variable``
            The resized word embeddings, or the current word embeddings if ``new_num_tokens`` is None or equal to
            the current number of tokens.
        """
        word_embeddings = self._get_word_embeddings(old_embeddings)
        if new_num_tokens is None:
            return word_embeddings
        old_num_tokens, old_embedding_dim = word_embeddings.shape
        if old_num_tokens == new_num_tokens:
            return word_embeddings

        # initialize new embeddings
        # todo: initializer range is not always passed in config.
        init_range = getattr(self.config, "initializer_range", 0.02)
        new_embeddings = self.add_weight(
            "weight",
            shape=[new_num_tokens, old_embedding_dim],
            initializer=get_initializer(init_range),
            dtype=tf.float32,
        )
        init_weights = new_embeddings.numpy()

        # Copy token embeddings from the previous weights
        num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
        init_weights[:num_tokens_to_copy] = word_embeddings[:num_tokens_to_copy, :]
        new_embeddings.assign(init_weights)

        return new_embeddings

    def prune_heads(self, heads_to_prune):
        """ Prunes heads of the base model.

            Not implemented in this base class.

            Arguments:

                heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the
                list of heads to prune in said layer (list of `int`).
        """
        raise NotImplementedError

    def save_pretrained(self, save_directory):
        """ Save a model and its configuration file to a directory, so that it
            can be re-loaded using the :func:`~transformers.TFPreTrainedModel.from_pretrained` class method.

            If ``save_directory`` is an existing file, an error is logged and nothing is saved. The directory is
            created if it does not exist.
        """
        if os.path.isfile(save_directory):
            logger.error("Provided path ({}) should be a directory, not a file".format(save_directory))
            return
        os.makedirs(save_directory, exist_ok=True)

        # Save configuration file
        self.config.save_pretrained(save_directory)

        # If we save using the predefined names, we can load using `from_pretrained`
        output_model_file = os.path.join(save_directory, TF2_WEIGHTS_NAME)
        self.save_weights(output_model_file)
        logger.info("Model weights saved in {}".format(output_model_file))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r"""Instantiate a pretrained TF 2.0 model from a pre-trained model configuration.

        The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not
        come pre-trained with the rest of the model.
        It is up to you to train those weights with a downstream fine-tuning task.

        The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those
        weights are discarded.

        Parameters:
            pretrained_model_name_or_path: either:
                - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.:
                  ``bert-base-uncased``.
                - a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.:
                  ``dbmdz/bert-base-german-cased``.
                - a path to a `directory` containing model weights saved using
                  :func:`~transformers.TFPreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
                - a path or url to a `PyTorch state_dict save file` (e.g. `./pt_model/pytorch_model.bin`). In this
                  case, ``from_pt`` should be set to True and a configuration object should be provided as ``config``
                  argument. This loading path is slower than converting the PyTorch checkpoint in a TensorFlow model
                  using the provided conversion scripts and loading the TensorFlow model afterwards.

            model_args: (`optional`) Sequence of positional arguments:
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method

            config: (`optional`) one of:
                    - an instance of a class derived from :class:`~transformers.PretrainedConfig`, or
                    - a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained()`

                Configuration for the model to use instead of an automatically loaded configuration. Configuration
                can be automatically loaded when:
                    - the model is a model provided by the library (loaded with the ``shortcut-name`` string of a
                      pretrained model), or
                    - the model was saved using :func:`~transformers.TFPreTrainedModel.save_pretrained` and is
                      reloaded by supplying the save directory.
                    - the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
                      configuration JSON file named `config.json` is found in the directory.

            from_pt: (`optional`) boolean, default False:
                Load the model weights from a PyTorch state_dict save file (see docstring of
                pretrained_model_name_or_path argument). When a directory contains both a PyTorch and a TF 2.0
                weights file, the PyTorch one is used if ``from_pt`` is True.

            cache_dir: (`optional`) string:
                Path to a directory in which a downloaded pre-trained model
                configuration should be cached if the standard cache should not be used.

            force_download: (`optional`) boolean, default False:
                Force to (re-)download the model weights and configuration files and override the cached versions if
                they exist.

            resume_download: (`optional`) boolean, default False:
                Do not delete incompletely received file. Attempt to resume the download if such a file exists.

            proxies: (`optional`) dict, default None:
                A dictionary of proxy servers to use by protocol or endpoint, e.g.:
                {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
                The proxies are used on each request.

            output_loading_info: (`optional`) boolean:
                Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error
                messages.

            local_files_only: (`optional`) boolean, default False:
                Forwarded to the configuration loader and to ``cached_path``.

            use_cdn: (`optional`) boolean, default True:
                Forwarded to ``hf_bucket_url`` when the weights location is built from a model identifier.

            kwargs: (`optional`) Remaining dictionary of keyword arguments:
                Can be used to update the configuration object (after it being loaded) and initiate the model.
                (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or
                automatically loaded:

                - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
                  underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
                  already been done)
                - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
                  initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
                  ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
                  with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
                  attribute will be passed to the underlying model's ``__init__`` function.

        Returns:
            The loaded model, or a ``(model, loading_info)`` tuple when ``output_loading_info`` is True (TF 2.0
            weights only; the PyTorch path returns the result of ``load_pytorch_checkpoint_in_tf2_model``).

        Raises:
            EnvironmentError: if no weights file can be located or retrieved.
            OSError: if the resolved file cannot be loaded as TF 2.0 h5 weights.

        Examples::

            # For example purposes. Not runnable.
            model = BertModel.from_pretrained('bert-base-uncased')    # Download model and configuration from S3 and cache.
            model = BertModel.from_pretrained('./test/saved_model/')  # E.g. model was saved using `save_pretrained('./test/saved_model/')`
            model = BertModel.from_pretrained('bert-base-uncased', output_attention=True)  # Update configuration during loading
            assert model.config.output_attention == True
            # Loading from a PyTorch checkpoint file instead of a TF 2.0 weights file (slower)
            config = BertConfig.from_json_file('./pt_model/my_pt_model_config.json')
            model = BertModel.from_pretrained('./pt_model/my_pytorch_model.bin', from_pt=True, config=config)

        """
        config = kwargs.pop("config", None)
        cache_dir = kwargs.pop("cache_dir", None)
        from_pt = kwargs.pop("from_pt", False)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        output_loading_info = kwargs.pop("output_loading_info", False)
        local_files_only = kwargs.pop("local_files_only", False)
        use_cdn = kwargs.pop("use_cdn", True)

        # Load config if we don't provide a configuration
        if not isinstance(config, PretrainedConfig):
            config_path = config if config is not None else pretrained_model_name_or_path
            config, model_kwargs = cls.config_class.from_pretrained(
                config_path,
                *model_args,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                **kwargs,
            )
        else:
            model_kwargs = kwargs

        # Load model
        if pretrained_model_name_or_path is not None:
            if os.path.isdir(pretrained_model_name_or_path):
                pt_weights_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
                tf2_weights_file = os.path.join(pretrained_model_name_or_path, TF2_WEIGHTS_NAME)
                # The PyTorch file is checked first: with `from_pt=True` the resolved file is handed to the
                # PyTorch loader below, so a TF 2.0 h5 file must not win when both files are present.
                if from_pt and os.path.isfile(pt_weights_file):
                    # Load from a PyTorch checkpoint
                    archive_file = pt_weights_file
                elif os.path.isfile(tf2_weights_file):
                    # Load from a TF 2.0 checkpoint
                    archive_file = tf2_weights_file
                else:
                    raise EnvironmentError(
                        "Error no file named {} found in directory {} or `from_pt` set to False".format(
                            [WEIGHTS_NAME, TF2_WEIGHTS_NAME], pretrained_model_name_or_path
                        )
                    )
            elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
                archive_file = pretrained_model_name_or_path
            elif os.path.isfile(pretrained_model_name_or_path + ".index"):
                archive_file = pretrained_model_name_or_path + ".index"
            else:
                archive_file = hf_bucket_url(
                    pretrained_model_name_or_path,
                    filename=(WEIGHTS_NAME if from_pt else TF2_WEIGHTS_NAME),
                    use_cdn=use_cdn,
                )

            try:
                # Load from URL or cache if already cached
                resolved_archive_file = cached_path(
                    archive_file,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                )
                if resolved_archive_file is None:
                    raise EnvironmentError
            except EnvironmentError:
                msg = (
                    f"Can't load weights for '{pretrained_model_name_or_path}'. Make sure that:\n\n"
                    f"- '{pretrained_model_name_or_path}' is a correct model identifier listed on 'https://huggingface.co/models'\n\n"
                    f"- or '{pretrained_model_name_or_path}' is the correct path to a directory containing a file named one of {TF2_WEIGHTS_NAME}, {WEIGHTS_NAME}.\n\n"
                )
                raise EnvironmentError(msg)
            if resolved_archive_file == archive_file:
                logger.info("loading weights file {}".format(archive_file))
            else:
                logger.info("loading weights file {} from cache at {}".format(archive_file, resolved_archive_file))
        else:
            resolved_archive_file = None

        # Instantiate model.
        model = cls(config, *model_args, **model_kwargs)

        if from_pt:
            # Load from a PyTorch checkpoint
            return load_pytorch_checkpoint_in_tf2_model(model, resolved_archive_file, allow_missing_keys=True)

        model(model.dummy_inputs, training=False)  # build the network with dummy inputs

        assert os.path.isfile(resolved_archive_file), "Error retrieving file {}".format(resolved_archive_file)
        # 'by_name' allow us to do transfer learning by skipping/adding layers
        # see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
        try:
            model.load_weights(resolved_archive_file, by_name=True)
        except OSError:
            raise OSError(
                "Unable to load weights from h5 file. "
                "If you tried to load a TF 2.0 model from a PyTorch checkpoint, please set from_pt=True. "
            )

        model(model.dummy_inputs, training=False)  # Make sure restore ops are run

        # Check if the models are the same to output loading informations
        with h5py.File(resolved_archive_file, "r") as f:
            if "layer_names" not in f.attrs and "model_weights" in f:
                f = f["model_weights"]
            hdf5_layer_names = set(hdf5_format.load_attributes_from_hdf5_group(f, "layer_names"))
        model_layer_names = set(layer.name for layer in model.layers)
        missing_keys = list(model_layer_names - hdf5_layer_names)
        unexpected_keys = list(hdf5_layer_names - model_layer_names)
        # Never filled in this code path; kept so that `loading_info` always carries the key.
        error_msgs = []

        if len(unexpected_keys) > 0:
            logger.warning(
                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
                f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
                f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
                f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n"
                f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
                f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
            )
        else:
            logger.warning(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
        if len(missing_keys) > 0:
            logger.warning(
                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
                f"and are newly initialized: {missing_keys}\n"
                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
            )
        else:
            logger.warning(
                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
                f"If your task is similar to the task the model of the ckeckpoint was trained on, "
                f"you can already use {model.__class__.__name__} for predictions without further training."
            )
        if len(error_msgs) > 0:
            raise RuntimeError(
                "Error(s) in loading weights for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))
            )
        if output_loading_info:
            loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs}
            return model, loading_info

        return model
class TFConv1D(tf.keras.layers.Layer):
    def __init__(self, nf, nx, initializer_range=0.02, **kwargs):
        """ TFConv1D layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2)
            Basically works like a Linear layer but the weights are transposed

        Args:
            nf: size of the last dimension of the output (number of output features).
            nx: size of the last dimension of the input (number of input features).
            initializer_range: stddev of the truncated normal initializer used for the weight.
            kwargs: forwarded to ``tf.keras.layers.Layer``.
        """
        super().__init__(**kwargs)
        self.nf = nf
        self.nx = nx
        self.initializer_range = initializer_range

    def build(self, input_shape):
        """Create the ``[nx, nf]`` weight and the zero-initialized ``[1, nf]`` bias."""
        self.weight = self.add_weight(
            "weight", shape=[self.nx, self.nf], initializer=get_initializer(self.initializer_range)
        )
        self.bias = self.add_weight("bias", shape=[1, self.nf], initializer=tf.zeros_initializer())

    def call(self, x):
        """Apply ``x @ weight + bias`` on the last dimension; the first two dimensions of ``x`` are restored
        on the output."""
        bz, sl = shape_list(x)[:2]

        x = tf.reshape(x, [-1, self.nx])
        x = tf.matmul(x, self.weight) + self.bias

        x = tf.reshape(x, [bz, sl, self.nf])

        return x


class TFSharedEmbeddings(tf.keras.layers.Layer):
    """Construct shared token embeddings.

    The same ``[vocab_size, hidden_size]`` weight is used both to look up embeddings (mode "embedding") and, transposed,
    to project hidden states back to the vocabulary (mode "linear").
    """

    def __init__(self, vocab_size, hidden_size, initializer_range=None, **kwargs):
        """
        Args:
            vocab_size: number of rows of the embedding matrix.
            hidden_size: number of columns of the embedding matrix.
            initializer_range: stddev of the weight initializer; defaults to ``hidden_size ** -0.5`` when None.
            kwargs: forwarded to ``tf.keras.layers.Layer``.
        """
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.initializer_range = hidden_size ** -0.5 if initializer_range is None else initializer_range

    def build(self, input_shape):
        """Build shared token embedding layer
        Shared weights logic adapted from
            https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
        """
        self.weight = self.add_weight(
            "weight", shape=[self.vocab_size, self.hidden_size], initializer=get_initializer(self.initializer_range)
        )
        super().build(input_shape)

    def get_config(self):
        """Return the Keras config extended with this layer's constructor arguments."""
        config = {
            "vocab_size": self.vocab_size,
            "hidden_size": self.hidden_size,
            "initializer_range": self.initializer_range,
        }
        base_config = super().get_config()

        # Entries defined here take precedence over the base config on key clashes.
        return {**base_config, **config}

    def call(self, inputs, mode="embedding"):
        """Get token embeddings of inputs, or project hidden states to the vocabulary.

        Args:
            inputs: in "embedding" mode, a tensor of token indices used to gather rows of the weight; in "linear"
                mode, a float tensor with shape [..., hidden_size].
            mode: string, a valid value is one of "embedding" and "linear".
        Returns:
            outputs: (1) If mode == "embedding", the gathered embeddings, i.e. the input shape with a trailing
                hidden_size dimension; (2) mode == "linear", a tensor with shape [..., vocab_size].
        Raises:
            ValueError: if mode is not valid.

        Shared weights logic adapted from
            https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
        """
        if mode == "embedding":
            return self._embedding(inputs)
        elif mode == "linear":
            return self._linear(inputs)
        else:
            raise ValueError("mode {} is not valid.".format(mode))

    def _embedding(self, input_ids):
        """Applies embedding based on inputs tensor."""
        return tf.gather(self.weight, input_ids)

    def _linear(self, inputs):
        """Computes logits by running inputs through a linear layer.
            Args:
                inputs: A float32 tensor with shape [..., hidden_size]
            Returns:
                float32 tensor with shape [..., vocab_size].
        """
        first_dims = shape_list(inputs)[:-1]

        x = tf.reshape(inputs, [-1, self.hidden_size])
        logits = tf.matmul(x, self.weight, transpose_b=True)

        return tf.reshape(logits, first_dims + [self.vocab_size])


class TFSequenceSummary(tf.keras.layers.Layer):
    r""" Compute a single vector summary of a sequence hidden states according to various possibilities:
        Args of the config class:
            summary_type:
                - 'last' => [default] take the last token hidden state (like XLNet)
                - 'first' => take the first token hidden state (like Bert)
                - 'mean' => take the mean of all tokens hidden states
                - 'cls_index' => supply a Tensor of classification token position (GPT/GPT-2)
                - 'attn' => Not implemented now, use multi-head attention
            summary_use_proj: Add a projection after the vector extraction
            summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to
                hidden_size). Default: False.
            summary_activation: 'tanh' => add a tanh activation to the output, any other value => no activation.
            summary_first_dropout: Add a dropout before the projection and activation
            summary_last_dropout: Add a dropout after the projection and activation
    """

    def __init__(self, config, initializer_range=0.02, **kwargs):
        """
        Args:
            config: configuration object; every ``summary_*`` attribute is optional (see the class docstring).
            initializer_range: stddev of the initializer of the projection kernel.
            kwargs: forwarded to ``tf.keras.layers.Layer``.

        Raises:
            NotImplementedError: if ``config.summary_type`` is ``"attn"``.
        """
        super().__init__(**kwargs)

        # Read the attribute that is actually used; fall back to "last" only when it is absent.
        self.summary_type = getattr(config, "summary_type", "last")
        if self.summary_type == "attn":
            # We should use a standard multi-head attention module with absolute positional embedding for that.
            # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
            # We can probably just use the multi-head attention module of PyTorch >=1.1.0
            raise NotImplementedError

        self.has_summary = hasattr(config, "summary_use_proj") and config.summary_use_proj
        if self.has_summary:
            if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
                num_classes = config.num_labels
            else:
                num_classes = config.hidden_size
            self.summary = tf.keras.layers.Dense(
                num_classes, kernel_initializer=get_initializer(initializer_range), name="summary"
            )

        self.has_activation = hasattr(config, "summary_activation") and config.summary_activation == "tanh"
        if self.has_activation:
            self.activation = tf.keras.activations.tanh

        self.has_first_dropout = hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0
        if self.has_first_dropout:
            self.first_dropout = tf.keras.layers.Dropout(config.summary_first_dropout)

        self.has_last_dropout = hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0
        if self.has_last_dropout:
            self.last_dropout = tf.keras.layers.Dropout(config.summary_last_dropout)

    def call(self, inputs, training=False):
        """ Summarize the hidden states.

            ``inputs`` is either the hidden states alone, a tuple/list ``(hidden_states[, cls_index])``, or a dict
            with the keys "hidden_states" and (optionally) "cls_index".

            hidden_states: float Tensor in shape [bsz, seq_len, hidden_size], the hidden-states of the last layer.
            cls_index: [optional] position of the classification token if summary_type == 'cls_index',
                shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states.
                if summary_type == 'cls_index' and cls_index is None:
                    we take the last token of the sequence as classification token
        """
        if not isinstance(inputs, (dict, tuple, list)):
            hidden_states = inputs
            cls_index = None
        elif isinstance(inputs, (tuple, list)):
            hidden_states = inputs[0]
            cls_index = inputs[1] if len(inputs) > 1 else None
            assert len(inputs) <= 2, "Too many inputs."
        else:
            hidden_states = inputs.get("hidden_states")
            cls_index = inputs.get("cls_index", None)

        if self.summary_type == "last":
            output = hidden_states[:, -1]
        elif self.summary_type == "first":
            output = hidden_states[:, 0]
        elif self.summary_type == "mean":
            output = tf.reduce_mean(hidden_states, axis=1)
        elif self.summary_type == "cls_index":
            hidden_shape = shape_list(hidden_states)  # e.g. [batch, num choices, seq length, hidden dims]
            if cls_index is None:
                # A tensor of shape [batch] or [batch, num choices] filled with the index of the last token
                cls_index = tf.fill(hidden_shape[:-2], hidden_shape[-2] - 1)
            cls_shape = shape_list(cls_index)
            if len(cls_shape) <= len(hidden_shape) - 2:
                # add a trailing axis so that one position is gathered per leading index
                cls_index = cls_index[..., tf.newaxis]
            output = tf.gather(hidden_states, cls_index, batch_dims=len(hidden_shape) - 2)
            output = tf.squeeze(
                output, axis=len(hidden_shape) - 2
            )  # shape of output: (batch, num choices, hidden_size)
        elif self.summary_type == "attn":
            raise NotImplementedError

        if self.has_first_dropout:
            output = self.first_dropout(output, training=training)

        if self.has_summary:
            output = self.summary(output)

        if self.has_activation:
            output = self.activation(output)

        if self.has_last_dropout:
            output = self.last_dropout(output, training=training)

        return output


def shape_list(x):
    """Deal with dynamic shape in tensorflow cleanly.

    Returns a list with one entry per dimension of ``x``: the static size when it is known, otherwise the
    corresponding element of ``tf.shape(x)``.
    """
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [dynamic[i] if s is None else s for i, s in enumerate(static)]


def get_initializer(initializer_range=0.02):
    """Creates a `tf.initializers.truncated_normal` with the given range.
    Args:
        initializer_range: float, initializer range for stddev.
    Returns:
        TruncatedNormal initializer with stddev = `initializer_range`.
    """
    return tf.keras.initializers.TruncatedNormal(stddev=initializer_range)


def cast_bool_to_primitive(bool_variable, default_tensor_to_true=False):
    """Function arguments can be inserted as boolean tensor
        and bool variables to cope with keras serialization
        we need to cast `output_attentions` to correct bool
        if it is a tensor

    Args:
        bool_variable: a python bool or a tensor holding a boolean value.
        default_tensor_to_true: bool, if tensor should default to True
        in case tensor has no numpy attribute

    Returns:
        The python bool held by the tensor when it exposes ``numpy``; ``True`` for a tensor without ``numpy`` when
        ``default_tensor_to_true`` is set; otherwise ``bool_variable`` itself, unchanged.
    """
    # if bool variable is tensor and has numpy value
    if tf.is_tensor(bool_variable):
        if hasattr(bool_variable, "numpy"):
            return bool(bool_variable.numpy())
        elif default_tensor_to_true:
            return True

    # else variable is bool
    return bool_variable