# Source code for transformers.modeling_tf_electra

from dataclasses import dataclass
from typing import Optional, Tuple

import tensorflow as tf

from .activations_tf import get_tf_activation
from .configuration_electra import ElectraConfig
from .file_utils import (
    MULTIPLE_CHOICE_DUMMY_INPUTS,
    ModelOutput,
    add_code_sample_docstrings,
    add_start_docstrings,
    add_start_docstrings_to_callable,
    replace_return_docstrings,
)
from .modeling_tf_outputs import (
    TFBaseModelOutput,
    TFMaskedLMOutput,
    TFMultipleChoiceModelOutput,
    TFQuestionAnsweringModelOutput,
    TFSequenceClassifierOutput,
    TFTokenClassifierOutput,
)
from .modeling_tf_utils import (
    TFMaskedLanguageModelingLoss,
    TFMultipleChoiceLoss,
    TFPreTrainedModel,
    TFQuestionAnsweringLoss,
    TFSequenceClassificationLoss,
    TFSequenceSummary,
    TFTokenClassificationLoss,
    get_initializer,
    keras_serializable,
    shape_list,
)
from .tokenization_utils import BatchEncoding
from .utils import logging


# Module-level logger, obtained through the library's own logging helper.
logger = logging.get_logger(__name__)

# Names handed to the docstring decorators applied to the model classes in this file.
_CONFIG_FOR_DOC = "ElectraConfig"
_TOKENIZER_FOR_DOC = "ElectraTokenizer"

# Identifiers of the ELECTRA generator and discriminator checkpoints (small, base and large variants).
TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "google/electra-small-generator",
    "google/electra-base-generator",
    "google/electra-large-generator",
    "google/electra-small-discriminator",
    "google/electra-base-discriminator",
    "google/electra-large-discriminator",
    # See all ELECTRA models at https://huggingface.co/models?filter=electra
]


# Copied from transformers.modeling_tf_bert.TFBertSelfAttention
class TFElectraSelfAttention(tf.keras.layers.Layer):
    """Multi-head scaled dot-product self-attention.

    Queries, keys and values are three separate dense projections of the same input. The projections are split into
    ``config.num_attention_heads`` heads, attention is computed per head, and the heads are concatenated again.
    """

    def __init__(self, config, **kwargs):
        """Create the query/key/value projections and the attention dropout.

        Raises:
            ValueError: if ``config.hidden_size`` is not a multiple of ``config.num_attention_heads``.
        """
        super().__init__(**kwargs)

        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = config.num_attention_heads
        # Divisibility was checked above, so floor division is exact and avoids a round-trip through float.
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.query = tf.keras.layers.Dense(
            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
        )
        self.key = tf.keras.layers.Dense(
            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key"
        )
        self.value = tf.keras.layers.Dense(
            self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
        )
        self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x, batch_size):
        """Reshape ``x`` to ``(batch_size, -1, num_heads, head_size)`` and move the head axis in front of the sequence axis."""
        x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))

        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
        """Run self-attention over ``hidden_states``.

        Args:
            hidden_states: input tensor; its first dimension is used as the batch size.
            attention_mask: optional additive mask that is summed onto the raw attention scores.
            head_mask: optional multiplicative mask applied to the attention probabilities.
            output_attentions: when truthy, the attention probabilities are returned as well.
            training: forwarded to the dropout layer.

        Returns:
            ``(context_layer,)`` or ``(context_layer, attention_probs)`` when ``output_attentions`` is truthy.
        """
        batch_size = shape_list(hidden_states)[0]
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)
        query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
        key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
        value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = tf.matmul(
            query_layer, key_layer, transpose_b=True
        )  # (batch size, num_heads, seq_len_q, seq_len_k)
        dk = tf.cast(shape_list(key_layer)[-1], attention_scores.dtype)  # scale attention_scores
        attention_scores = attention_scores / tf.math.sqrt(dk)

        if attention_mask is not None:
            # Apply the additive attention mask handed in by the caller.
            attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = tf.nn.softmax(attention_scores, axis=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs, training=training)

        # Mask heads if we want to
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = tf.matmul(attention_probs, value_layer)
        context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
        context_layer = tf.reshape(
            context_layer, (batch_size, -1, self.all_head_size)
        )  # (batch_size, seq_len_q, all_head_size)
        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

        return outputs


# Copied from transformers.modeling_tf_bert.TFBertSelfOutput
class TFElectraSelfOutput(tf.keras.layers.Layer):
    """Dense projection of the attention output, followed by dropout and a residual LayerNorm."""

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        kernel_init = get_initializer(config.initializer_range)
        self.dense = tf.keras.layers.Dense(config.hidden_size, kernel_initializer=kernel_init, name="dense")
        self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

    def call(self, hidden_states, input_tensor, training=False):
        """Project ``hidden_states``, apply dropout, add ``input_tensor`` and normalize the sum."""
        projected = self.dropout(self.dense(hidden_states), training=training)

        return self.LayerNorm(projected + input_tensor)


# Copied from from transformers.modeling_tf_bert.TFBertAttention with Bert->Electra
class TFElectraAttention(tf.keras.layers.Layer):
    """Self-attention followed by its output projection (dense, dropout, residual LayerNorm)."""

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.self_attention = TFElectraSelfAttention(config, name="self")
        self.dense_output = TFElectraSelfOutput(config, name="output")

    def prune_heads(self, heads):
        """Head pruning is not implemented for this layer."""
        raise NotImplementedError

    def call(self, input_tensor, attention_mask, head_mask, output_attentions, training=False):
        """Return ``(attention_output,)`` plus the attention probabilities when they are requested."""
        context, *extras = self.self_attention(
            input_tensor, attention_mask, head_mask, output_attentions, training=training
        )
        attention_output = self.dense_output(context, input_tensor, training=training)

        # ``extras`` holds the attention probabilities when the self-attention layer returned them.
        return (attention_output, *extras)


# Copied from transformers.modeling_tf_bert.TFBertIntermediate
class TFElectraIntermediate(tf.keras.layers.Layer):
    """Dense expansion to ``config.intermediate_size`` followed by the configured activation."""

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.dense = tf.keras.layers.Dense(
            config.intermediate_size,
            kernel_initializer=get_initializer(config.initializer_range),
            name="dense",
        )

        # ``hidden_act`` is either the name of an activation or the activation callable itself.
        activation = config.hidden_act
        self.intermediate_act_fn = get_tf_activation(activation) if isinstance(activation, str) else activation

    def call(self, hidden_states):
        """Apply the dense layer and then the activation function."""
        return self.intermediate_act_fn(self.dense(hidden_states))


# Copied from transformers.modeling_tf_bert.TFBertOutput
class TFElectraOutput(tf.keras.layers.Layer):
    """Dense projection back to ``config.hidden_size``, followed by dropout and a residual LayerNorm."""

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        kernel_init = get_initializer(config.initializer_range)
        self.dense = tf.keras.layers.Dense(config.hidden_size, kernel_initializer=kernel_init, name="dense")
        self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

    def call(self, hidden_states, input_tensor, training=False):
        """Project ``hidden_states``, apply dropout, add ``input_tensor`` and normalize the sum."""
        projected = self.dropout(self.dense(hidden_states), training=training)

        return self.LayerNorm(projected + input_tensor)


# Copied from transformers.modeling_tf_bert.TFBertLayer with Bert->Electra
class TFElectraLayer(tf.keras.layers.Layer):
    """One encoder block: attention, intermediate expansion and output projection."""

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.attention = TFElectraAttention(config, name="attention")
        self.intermediate = TFElectraIntermediate(config, name="intermediate")
        self.bert_output = TFElectraOutput(config, name="output")

    def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
        """Return ``(layer_output,)`` plus the attention probabilities when they are requested."""
        attn_result = self.attention(hidden_states, attention_mask, head_mask, output_attentions, training=training)
        attended = attn_result[0]
        layer_output = self.bert_output(self.intermediate(attended), attended, training=training)

        # Anything after the first element is the attention probabilities, passed through untouched.
        return (layer_output,) + attn_result[1:]


# Copied from transformers.modeling_tf_bert.TFBertEncoder with Bert->Electra
class TFElectraEncoder(tf.keras.layers.Layer):
    """Stack of ``config.num_hidden_layers`` :class:`TFElectraLayer` blocks applied in sequence."""

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.layer = [TFElectraLayer(config, name="layer_._%d" % idx) for idx in range(config.num_hidden_layers)]

    def call(
        self,
        hidden_states,
        attention_mask,
        head_mask,
        output_attentions,
        output_hidden_states,
        return_dict,
        training=False,
    ):
        """Run every layer and optionally collect the per-layer hidden states and attentions.

        Returns:
            A :class:`TFBaseModelOutput` when ``return_dict`` is truthy, otherwise a tuple holding the last hidden
            state followed by whichever of the collected hidden states / attentions were requested.
        """
        collected_states = () if output_hidden_states else None
        collected_attentions = () if output_attentions else None

        for idx, block in enumerate(self.layer):
            if output_hidden_states:
                # Record the input of this layer (for the first layer, the embedding output).
                collected_states += (hidden_states,)

            block_outputs = block(hidden_states, attention_mask, head_mask[idx], output_attentions, training=training)
            hidden_states = block_outputs[0]

            if output_attentions:
                collected_attentions += (block_outputs[1],)

        if output_hidden_states:
            # Record the output of the final layer as well.
            collected_states += (hidden_states,)

        if return_dict:
            return TFBaseModelOutput(
                last_hidden_state=hidden_states,
                hidden_states=collected_states,
                attentions=collected_attentions,
            )

        candidates = (hidden_states, collected_states, collected_attentions)
        return tuple(item for item in candidates if item is not None)


# Copied from transformers.modeling_tf_bert.TFBertPooler
class TFElectraPooler(tf.keras.layers.Layer):
    """Pools a sequence by passing the hidden state of its first token through a tanh dense layer."""

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.dense = tf.keras.layers.Dense(
            config.hidden_size,
            kernel_initializer=get_initializer(config.initializer_range),
            activation="tanh",
            name="dense",
        )

    def call(self, hidden_states):
        """Return the dense projection of ``hidden_states[:, 0]``."""
        # "Pooling" here simply means keeping the hidden state of the first token.
        return self.dense(hidden_states[:, 0])


class TFElectraEmbeddings(tf.keras.layers.Layer):
    """Construct the embeddings from word, position and token_type embeddings.

    The word embedding matrix is created in :meth:`build`. It is used for the lookup in ``"embedding"`` mode and,
    transposed, as the projection matrix in ``"linear"`` mode.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.vocab_size = config.vocab_size
        self.embedding_size = config.embedding_size
        self.initializer_range = config.initializer_range
        self.position_embeddings = tf.keras.layers.Embedding(
            config.max_position_embeddings,
            config.embedding_size,
            embeddings_initializer=get_initializer(self.initializer_range),
            name="position_embeddings",
        )
        self.token_type_embeddings = tf.keras.layers.Embedding(
            config.type_vocab_size,
            config.embedding_size,
            embeddings_initializer=get_initializer(self.initializer_range),
            name="token_type_embeddings",
        )

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)

    def build(self, input_shape):
        """Create the word embedding matrix of shape ``[vocab_size, embedding_size]``."""
        # The weight is created inside the "word_embeddings" name scope.
        with tf.name_scope("word_embeddings"):
            # Create and initialize weights. The random normal initializer was chosen
            # arbitrarily, and works well.
            self.word_embeddings = self.add_weight(
                "weight",
                shape=[self.vocab_size, self.embedding_size],
                initializer=get_initializer(self.initializer_range),
            )

        super().build(input_shape)

    # Copied from transformers.modeling_tf_bert.TFBertEmbeddings.call
    def call(
        self,
        input_ids=None,
        position_ids=None,
        token_type_ids=None,
        inputs_embeds=None,
        mode="embedding",
        training=False,
    ):
        """Embed the inputs or project hidden states onto the vocabulary, depending on ``mode``.

        Args:
            input_ids: in ``"embedding"`` mode, the token ids to look up; in ``"linear"`` mode, the tensor that is
                projected with the transposed word embedding matrix (it is passed straight to :meth:`_linear`).
            position_ids: optional position ids, only used in ``"embedding"`` mode.
            token_type_ids: optional token type ids, only used in ``"embedding"`` mode.
            inputs_embeds: optional pre-computed word embeddings used instead of looking up ``input_ids``, only
                used in ``"embedding"`` mode.
            mode: string, a valid value is one of "embedding" and "linear".
            training: forwarded to the dropout layer in ``"embedding"`` mode.
        Returns:
            outputs: (1) If mode == "embedding", the embedding tensor produced by :meth:`_embedding`;
                (2) if mode == "linear", a tensor of shape [batch_size, length, vocab_size].
        Raises:
            ValueError: if mode is not valid.

        Shared weights logic adapted from
            https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
        """
        if mode == "embedding":
            return self._embedding(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
        elif mode == "linear":
            return self._linear(input_ids)
        else:
            raise ValueError("mode {} is not valid.".format(mode))

    # Copied from transformers.modeling_tf_bert.TFBertEmbeddings._embedding
    def _embedding(self, input_ids, position_ids, token_type_ids, inputs_embeds, training=False):
        """Sum word, position and token type embeddings, then apply LayerNorm and dropout.

        At least one of ``input_ids`` and ``inputs_embeds`` must be given (checked with an ``assert``).
        """
        assert not (input_ids is None and inputs_embeds is None)

        # The (batch, sequence) shape comes from the ids, or from the embeddings without their last axis.
        if input_ids is not None:
            input_shape = shape_list(input_ids)
        else:
            input_shape = shape_list(inputs_embeds)[:-1]

        seq_length = input_shape[1]

        # Default positions are 0..seq_length-1 with a leading axis of size 1 that broadcasts over the batch.
        if position_ids is None:
            position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :]

        # Default token types are all zeros.
        if token_type_ids is None:
            token_type_ids = tf.fill(input_shape, 0)

        if inputs_embeds is None:
            inputs_embeds = tf.gather(self.word_embeddings, input_ids)

        # Cast to the dtype of the word embeddings so that the three terms can be added.
        position_embeddings = tf.cast(self.position_embeddings(position_ids), inputs_embeds.dtype)
        token_type_embeddings = tf.cast(self.token_type_embeddings(token_type_ids), inputs_embeds.dtype)
        embeddings = inputs_embeds + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings, training=training)

        return embeddings

    def _linear(self, inputs):
        """Computes logits by multiplying inputs with the transposed word embedding matrix.
        Args:
            inputs: A tensor with shape [batch_size, length, embedding_size]
        Returns:
            tensor with shape [batch_size, length, vocab_size].
        """
        batch_size = shape_list(inputs)[0]
        length = shape_list(inputs)[1]
        # Flatten batch and length so that a single 2D matmul covers every position.
        x = tf.reshape(inputs, [-1, self.embedding_size])
        logits = tf.matmul(x, self.word_embeddings, transpose_b=True)

        return tf.reshape(logits, [batch_size, length, self.vocab_size])


class TFElectraDiscriminatorPredictions(tf.keras.layers.Layer):
    """Prediction head producing one logit per token (dense -> activation -> dense to a single unit)."""

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.dense = tf.keras.layers.Dense(config.hidden_size, name="dense")
        self.dense_prediction = tf.keras.layers.Dense(1, name="dense_prediction")
        self.config = config

    def call(self, discriminator_hidden_states, training=False):
        """Compute the per-token logits.

        Args:
            discriminator_hidden_states: hidden states coming out of the encoder.
            training: accepted for interface compatibility; this head has no training-dependent behaviour.

        Returns:
            The output of the single-unit dense layer with its trailing axis of size 1 removed.
        """
        hidden_states = self.dense(discriminator_hidden_states)

        # ``hidden_act`` may be an activation name or the callable itself (same handling as TFElectraIntermediate).
        hidden_act = self.config.hidden_act
        activation_fn = get_tf_activation(hidden_act) if isinstance(hidden_act, str) else hidden_act
        hidden_states = activation_fn(hidden_states)

        # Squeeze only the last axis: a bare tf.squeeze would also drop the batch (or sequence) axis whenever
        # it happens to have size 1, changing the rank of the result.
        logits = tf.squeeze(self.dense_prediction(hidden_states), axis=-1)

        return logits


class TFElectraGeneratorPredictions(tf.keras.layers.Layer):
    """Generator head transform: dense projection to ``config.embedding_size``, GELU, then LayerNorm."""

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
        self.dense = tf.keras.layers.Dense(config.embedding_size, name="dense")

    def call(self, generator_hidden_states, training=False):
        """Return the normalized, GELU-activated projection of ``generator_hidden_states``."""
        gelu = get_tf_activation("gelu")
        projected = self.dense(generator_hidden_states)

        return self.LayerNorm(gelu(projected))


class TFElectraPreTrainedModel(TFPreTrainedModel):
    """Common base class of the ELECTRA TensorFlow models defined in this file.

    It only binds the configuration class and the attribute name under which the models store their main layer.
    """

    config_class = ElectraConfig
    base_model_prefix = "electra"


@keras_serializable
class TFElectraMainLayer(tf.keras.layers.Layer):
    """Core ELECTRA network: embeddings, an optional projection to the hidden size, and the encoder stack."""

    config_class = ElectraConfig

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.embeddings = TFElectraEmbeddings(config, name="embeddings")

        # The projection layer only exists when the embedding width differs from the encoder width;
        # ``call`` checks for the attribute with ``hasattr``.
        if config.embedding_size != config.hidden_size:
            self.embeddings_project = tf.keras.layers.Dense(config.hidden_size, name="embeddings_project")

        self.encoder = TFElectraEncoder(config, name="encoder")
        self.config = config

    def get_input_embeddings(self):
        """Return the embeddings layer."""
        return self.embeddings

    def set_input_embeddings(self, value):
        """Replace the word embedding matrix and update ``vocab_size`` from its first dimension."""
        self.embeddings.word_embeddings = value
        self.embeddings.vocab_size = value.shape[0]

    def _resize_token_embeddings(self, new_num_tokens):
        """Not implemented for this layer."""
        raise NotImplementedError

    def _prune_heads(self, heads_to_prune):
        """Prunes heads of the model.
        heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        See base class PreTrainedModel

        Not implemented: always raises ``NotImplementedError``.
        """
        raise NotImplementedError

    def get_extended_attention_mask(self, attention_mask, input_shape, dtype):
        """Turn a 2D mask into an additive mask that broadcasts over heads and query positions.

        Args:
            attention_mask: 2D mask, or ``None`` to attend to every position of ``input_shape``.
            input_shape: shape used to build the all-ones mask when ``attention_mask`` is ``None``.
            dtype: dtype the mask is cast to before it is converted to additive form.

        Returns:
            A 4D tensor that is 0.0 where the mask was 1 and -10000.0 where it was 0.
        """
        if attention_mask is None:
            attention_mask = tf.fill(input_shape, 1)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = tf.cast(extended_attention_mask, dtype)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        return extended_attention_mask

    def get_head_mask(self, head_mask):
        """Return one ``None`` entry per hidden layer; an actual head mask is not supported.

        Raises:
            NotImplementedError: if ``head_mask`` is not ``None``.
        """
        if head_mask is not None:
            raise NotImplementedError
        else:
            head_mask = [None] * self.config.num_hidden_layers

        return head_mask

    def call(
        self,
        inputs,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        training=False,
    ):
        """Embed the inputs and run them through the encoder.

        ``inputs`` may be the ``input_ids`` tensor itself, a tuple/list of positional inputs in the order
        ``(input_ids, attention_mask, token_type_ids, position_ids, head_mask, inputs_embeds, output_attentions,
        output_hidden_states, return_dict)``, or a dict / ``BatchEncoding`` keyed by those names. Values found
        in the container take precedence over the matching keyword arguments.

        Returns:
            Whatever the encoder returns: a :class:`TFBaseModelOutput` when ``return_dict`` resolves to a truthy
            value, otherwise a tuple.

        Raises:
            ValueError: if both or neither of ``input_ids`` and ``inputs_embeds`` are provided.
        """
        if isinstance(inputs, (tuple, list)):
            # Positional container: missing trailing entries fall back to the keyword arguments.
            input_ids = inputs[0]
            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
            token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
            position_ids = inputs[3] if len(inputs) > 3 else position_ids
            head_mask = inputs[4] if len(inputs) > 4 else head_mask
            inputs_embeds = inputs[5] if len(inputs) > 5 else inputs_embeds
            output_attentions = inputs[6] if len(inputs) > 6 else output_attentions
            output_hidden_states = inputs[7] if len(inputs) > 7 else output_hidden_states
            return_dict = inputs[8] if len(inputs) > 8 else return_dict
            assert len(inputs) <= 9, "Too many inputs."
        elif isinstance(inputs, (dict, BatchEncoding)):
            # Keyed container: absent keys fall back to the keyword arguments.
            input_ids = inputs.get("input_ids")
            attention_mask = inputs.get("attention_mask", attention_mask)
            token_type_ids = inputs.get("token_type_ids", token_type_ids)
            position_ids = inputs.get("position_ids", position_ids)
            head_mask = inputs.get("head_mask", head_mask)
            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
            output_attentions = inputs.get("output_attentions", output_attentions)
            output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
            return_dict = inputs.get("return_dict", return_dict)
            # Counts every key present in the container, whether or not it was read above.
            assert len(inputs) <= 9, "Too many inputs."
        else:
            input_ids = inputs

        # Unset output options default to the values stored on the config.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = shape_list(input_ids)
        elif inputs_embeds is not None:
            input_shape = shape_list(inputs_embeds)[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # Default mask attends everywhere; default token types are all zeros.
        if attention_mask is None:
            attention_mask = tf.fill(input_shape, 1)

        if token_type_ids is None:
            token_type_ids = tf.fill(input_shape, 0)

        hidden_states = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, hidden_states.dtype)
        head_mask = self.get_head_mask(head_mask)

        # Only present when embedding_size != hidden_size (see __init__).
        if hasattr(self, "embeddings_project"):
            hidden_states = self.embeddings_project(hidden_states, training=training)

        hidden_states = self.encoder(
            hidden_states,
            extended_attention_mask,
            head_mask,
            output_attentions,
            output_hidden_states,
            return_dict,
            training=training,
        )

        return hidden_states


@dataclass
class TFElectraForPreTrainingOutput(ModelOutput):
    """
    Output type of :class:`~transformers.TFElectraForPreTraining`.

    Args:
        logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
            Prediction scores of the head (scores for each token before SoftMax).
        hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
            Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`tf.Tensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    # There is deliberately no ``loss`` entry in the docstring: this dataclass declares no such field.
    logits: tf.Tensor = None
    hidden_states: Optional[Tuple[tf.Tensor]] = None
    attentions: Optional[Tuple[tf.Tensor]] = None


# Shared class-level documentation, attached to the model classes through ``add_start_docstrings``.
ELECTRA_START_DOCSTRING = r"""

    This model inherits from :class:`~transformers.TFPreTrainedModel`. Check the superclass documentation for the
    generic methods the library implements for all its model (such as downloading or saving, resizing the input
    embeddings, pruning heads etc.)

    This model is also a `tf.keras.Model <https://www.tensorflow.org/api_docs/python/tf/keras/Model>`__ subclass.
    Use it as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general
    usage and behavior.

    .. note::

        TF 2.0 models accepts two formats as inputs:

        - having all inputs as keyword arguments (like PyTorch models), or
        - having all inputs as a list, tuple or dict in the first positional arguments.

        This second option is useful when using :meth:`tf.keras.Model.fit` method which currently requires having
        all the tensors in the first argument of the model call function: :obj:`model(inputs)`.

        If you choose this second option, there are three possibilities you can use to gather all the input Tensors
        in the first positional argument :

        - a single Tensor with :obj:`input_ids` only and nothing else: :obj:`model(inputs_ids)`
        - a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
          :obj:`model([input_ids, attention_mask])` or :obj:`model([input_ids, attention_mask, token_type_ids])`
        - a dictionary with one or several input Tensors associated to the input names given in the docstring:
          :obj:`model({"input_ids": input_ids, "token_type_ids": token_type_ids})`

    Parameters:
        config (:class:`~transformers.ElectraConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.TFPreTrainedModel.from_pretrained` method to load the model weights.
"""

# Shared argument documentation. ``{0}`` is filled with the input shape through ``str.format`` before the text is
# attached to a ``call`` method, so no other curly braces may appear in this string.
ELECTRA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`({0})`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using :class:`~transformers.ElectraTokenizer`.
            See :func:`transformers.PreTrainedTokenizer.__call__` and
            :func:`transformers.PreTrainedTokenizer.encode` for details.

            `What are input IDs? <../glossary.html#input-ids>`__
        attention_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`({0})`, `optional`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            `What are attention masks? <../glossary.html#attention-mask>`__
        token_type_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`({0})`, `optional`):
            Indices of the token type of each input token, used to look up the token type embeddings.
            Selected in the range ``[0, config.type_vocab_size - 1]``. Defaults to all zeros when not provided.
        position_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`({0})`, `optional`):
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.

            `What are position IDs? <../glossary.html#position-ids>`__
        head_mask (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.

        inputs_embeds (:obj:`tf.Tensor` of shape :obj:`({0}, hidden_size)`, `optional`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
            vectors than the model's internal embedding lookup matrix.
        output_attentions (:obj:`bool`, `optional`):
            Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
            tensors for more detail.
        output_hidden_states (:obj:`bool`, `optional`):
            Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
            more detail.
        return_dict (:obj:`bool`, `optional`):
            Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
        training (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not to use the model in training mode (some modules like dropout modules have different
            behaviors between training and evaluation).
"""


# NOTE: documentation for this class and its ``call`` is attached by the decorators, so explanatory notes are kept
# in ``#`` comments instead of docstrings to leave the generated documentation unchanged.
@add_start_docstrings(
    "The bare Electra Model transformer outputting raw hidden-states without any specific head on top. Identical to "
    "the BERT model except that it uses an additional linear layer between the embedding layer and the encoder if the "
    "hidden size and embedding size are different. "
    "Both the generator and discriminator checkpoints may be loaded into this model.",
    ELECTRA_START_DOCSTRING,
)
class TFElectraModel(TFElectraPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)

        self.electra = TFElectraMainLayer(config, name="electra")

    @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="google/electra-small-discriminator",
        output_type=TFBaseModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(self, inputs, **kwargs):
        # Thin wrapper: every argument is forwarded to the main layer and its result is returned unchanged.
        outputs = self.electra(inputs, **kwargs)

        return outputs


# NOTE: the class documentation is attached by ``add_start_docstrings``; explanatory notes are kept in ``#``
# comments so the generated documentation is left unchanged.
@add_start_docstrings(
    """Electra model with a binary classification head on top as used during pre-training for identifying generated
    tokens.

    Even though both the discriminator and generator may be loaded into this model, the discriminator is
    the only model of the two to have the correct classification head to be used for this model.""",
    ELECTRA_START_DOCSTRING,
)
class TFElectraForPreTraining(TFElectraPreTrainedModel):
    def __init__(self, config, **kwargs):
        super().__init__(config, **kwargs)

        self.electra = TFElectraMainLayer(config, name="electra")
        self.discriminator_predictions = TFElectraDiscriminatorPredictions(config, name="discriminator_predictions")

    @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=TFElectraForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
    def call(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        training=False,
    ):
        r"""
        Returns:

        Examples::

            >>> import tensorflow as tf
            >>> from transformers import ElectraTokenizer, TFElectraForPreTraining

            >>> tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-discriminator')
            >>> model = TFElectraForPreTraining.from_pretrained('google/electra-small-discriminator')
            >>> input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
            >>> outputs = model(input_ids)
            >>> scores = outputs[0]
        """
        # ``Returns:`` above must stay alone on its line: ``replace_return_docstrings`` uses it as a placeholder.
        return_dict = return_dict if return_dict is not None else self.electra.config.return_dict

        discriminator_hidden_states = self.electra(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            inputs_embeds,
            output_attentions,
            output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        # Index 0 is the last hidden state for both the tuple and the ModelOutput return forms.
        discriminator_sequence_output = discriminator_hidden_states[0]
        logits = self.discriminator_predictions(discriminator_sequence_output, training=training)

        if not return_dict:
            return (logits,) + discriminator_hidden_states[1:]

        return TFElectraForPreTrainingOutput(
            logits=logits,
            hidden_states=discriminator_hidden_states.hidden_states,
            attentions=discriminator_hidden_states.attentions,
        )


class TFElectraMaskedLMHead(tf.keras.layers.Layer):
    """Output projection for the masked language modeling head.

    Calls ``input_embeddings`` in ``"linear"`` mode on the incoming hidden states and adds a trainable bias
    of shape ``(config.vocab_size,)``.
    """

    def __init__(self, config, input_embeddings, **kwargs):
        super().__init__(**kwargs)
        # The embedding layer is stored by reference, not copied.
        self.input_embeddings = input_embeddings
        self.vocab_size = config.vocab_size

    def build(self, input_shape):
        # The bias is the only weight this layer creates itself.
        self.bias = self.add_weight(
            name="bias",
            shape=(self.vocab_size,),
            initializer="zeros",
            trainable=True,
        )
        super().build(input_shape)

    def call(self, hidden_states, training=False):
        projected = self.input_embeddings(hidden_states, mode="linear")
        return projected + self.bias
# ELECTRA encoder followed by the generator prediction layer and the
# masked-LM output head.
@add_start_docstrings(
    """Electra model with a language modeling head on top.

    Even though both the discriminator and generator may be loaded into this model, the generator is
    the only model of the two to have been trained for the masked language modeling task.""",
    ELECTRA_START_DOCSTRING,
)
class TFElectraForMaskedLM(TFElectraPreTrainedModel, TFMaskedLanguageModelingLoss):
    def __init__(self, config, **kwargs):
        """Create the encoder, the generator prediction layer and the LM head from ``config``."""
        super().__init__(config, **kwargs)

        self.vocab_size = config.vocab_size
        self.electra = TFElectraMainLayer(config, name="electra")
        self.generator_predictions = TFElectraGeneratorPredictions(config, name="generator_predictions")

        # ``hidden_act`` may be a registered activation name or already a callable.
        if isinstance(config.hidden_act, str):
            self.activation = get_tf_activation(config.hidden_act)
        else:
            self.activation = config.hidden_act

        # The LM head is handed the encoder's embedding layer for its output projection.
        self.generator_lm_head = TFElectraMaskedLMHead(config, self.electra.embeddings, name="generator_lm_head")

    def get_output_embeddings(self):
        """Return the layer that produces the vocabulary scores (``generator_lm_head``)."""
        return self.generator_lm_head

    @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="google/electra-small-generator",
        output_type=TFMaskedLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
        training=False,
    ):
        r"""
        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the masked language modeling loss.
            Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
            Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
            in ``[0, ..., config.vocab_size]``
        """
        return_dict = return_dict if return_dict is not None else self.electra.config.return_dict

        # ``input_ids`` may be a tensor, a positional container or a mapping.
        # Labels are not an encoder input, so they are split off before forwarding.
        if isinstance(input_ids, (tuple, list)):
            labels = input_ids[9] if len(input_ids) > 9 else labels
            if len(input_ids) > 9:
                input_ids = input_ids[:9]
        elif isinstance(input_ids, (dict, BatchEncoding)):
            # Read the labels and forward a filtered copy instead of popping,
            # so the caller's mapping still has its "labels" entry afterwards
            # and can be fed to the model again.
            labels = input_ids.get("labels", labels)
            input_ids = {name: value for name, value in input_ids.items() if name != "labels"}

        generator_hidden_states = self.electra(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        generator_sequence_output = generator_hidden_states[0]
        prediction_scores = self.generator_predictions(generator_sequence_output, training=training)
        prediction_scores = self.generator_lm_head(prediction_scores, training=training)

        # The loss is only computed when labels were supplied.
        loss = None if labels is None else self.compute_loss(labels, prediction_scores)

        if not return_dict:
            output = (prediction_scores,) + generator_hidden_states[1:]
            return ((loss,) + output) if loss is not None else output

        return TFMaskedLMOutput(
            loss=loss,
            logits=prediction_scores,
            hidden_states=generator_hidden_states.hidden_states,
            attentions=generator_hidden_states.attentions,
        )
class TFElectraClassificationHead(tf.keras.layers.Layer):
    """Head for sentence-level classification tasks.

    Pipeline applied to the first position of the input sequence:
    dropout -> dense(hidden_size) -> gelu -> dropout -> dense(num_labels).
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)

        self.dense = tf.keras.layers.Dense(
            config.hidden_size,
            kernel_initializer=get_initializer(config.initializer_range),
            name="dense",
        )
        # A single Dropout layer is reused before and after the hidden projection.
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
        self.out_proj = tf.keras.layers.Dense(
            config.num_labels,
            kernel_initializer=get_initializer(config.initializer_range),
            name="out_proj",
        )

    def call(self, inputs, **kwargs):
        first_token = inputs[:, 0, :]  # take <s> token (equiv. to [CLS])
        first_token = self.dropout(first_token)

        hidden = self.dense(first_token)
        # although BERT uses tanh here, it seems Electra authors used gelu here
        gelu = get_tf_activation("gelu")
        hidden = gelu(hidden)
        hidden = self.dropout(hidden)

        return self.out_proj(hidden)
# ELECTRA encoder followed by the sentence-level classification head.
@add_start_docstrings(
    """ELECTRA Model transformer with a sequence classification/regression head on top (a linear layer on top of
    the pooled output) e.g. for GLUE tasks.
    """,
    ELECTRA_START_DOCSTRING,
)
class TFElectraForSequenceClassification(TFElectraPreTrainedModel, TFSequenceClassificationLoss):
    def __init__(self, config, *inputs, **kwargs):
        """Create the encoder and the ``classifier`` head from ``config``."""
        super().__init__(config, *inputs, **kwargs)

        self.num_labels = config.num_labels
        self.electra = TFElectraMainLayer(config, name="electra")
        self.classifier = TFElectraClassificationHead(config, name="classifier")

    @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="google/electra-small-discriminator",
        output_type=TFSequenceClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
        training=False,
    ):
        r"""
        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss.
            Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
            If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.electra.config.return_dict

        # ``input_ids`` may be a tensor, a positional container or a mapping.
        # Labels are not an encoder input, so they are split off before forwarding.
        if isinstance(input_ids, (tuple, list)):
            labels = input_ids[9] if len(input_ids) > 9 else labels
            if len(input_ids) > 9:
                input_ids = input_ids[:9]
        elif isinstance(input_ids, (dict, BatchEncoding)):
            # Read the labels and forward a filtered copy instead of popping,
            # so the caller's mapping still has its "labels" entry afterwards
            # and can be fed to the model again.
            labels = input_ids.get("labels", labels)
            input_ids = {name: value for name, value in input_ids.items() if name != "labels"}

        outputs = self.electra(
            input_ids,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            inputs_embeds,
            output_attentions,
            output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        logits = self.classifier(outputs[0])

        # The loss is only computed when labels were supplied.
        loss = None if labels is None else self.compute_loss(labels, logits)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TFSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
# ELECTRA encoder scored once per choice: the choice axis is folded into the
# batch axis for the encoder and unfolded again on the logits.
@add_start_docstrings(
    """ELECTRA Model with a multiple choice classification head on top (a linear layer on top of
    the pooled output and a softmax) e.g. for RocStories/SWAG tasks.
    """,
    ELECTRA_START_DOCSTRING,
)
class TFElectraForMultipleChoice(TFElectraPreTrainedModel, TFMultipleChoiceLoss):
    def __init__(self, config, *inputs, **kwargs):
        """Create the encoder, the sequence summary layer and the one-unit ``classifier`` from ``config``."""
        super().__init__(config, *inputs, **kwargs)

        self.electra = TFElectraMainLayer(config, name="electra")
        self.sequence_summary = TFSequenceSummary(
            config,
            initializer_range=config.initializer_range,
            name="sequence_summary",
        )
        # One output unit: a single score per flattened (example, choice) row.
        self.classifier = tf.keras.layers.Dense(
            1,
            kernel_initializer=get_initializer(config.initializer_range),
            name="classifier",
        )

    @property
    def dummy_inputs(self):
        """Dummy inputs to build the network.

        Returns:
            dict with a single ``"input_ids"`` entry holding ``MULTIPLE_CHOICE_DUMMY_INPUTS`` as a constant tensor
        """
        dummy_ids = tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS)
        return {"input_ids": dummy_ids}

    @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="google/electra-small-discriminator",
        output_type=TFMultipleChoiceModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        inputs,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
        training=False,
    ):
        r"""
        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the multiple choice classification loss.
            Indices should be in ``[0, ..., num_choices]`` where :obj:`num_choices` is the size of the second dimension
            of the input tensors. (See :obj:`input_ids` above)
        """
        if isinstance(inputs, (tuple, list)):
            # Positional container: slot i replaces the matching keyword
            # argument only when the container is long enough to hold it.
            given = len(inputs)
            input_ids = inputs[0]
            if given > 1:
                attention_mask = inputs[1]
            if given > 2:
                token_type_ids = inputs[2]
            if given > 3:
                position_ids = inputs[3]
            if given > 4:
                head_mask = inputs[4]
            if given > 5:
                inputs_embeds = inputs[5]
            if given > 6:
                output_attentions = inputs[6]
            if given > 7:
                output_hidden_states = inputs[7]
            if given > 8:
                return_dict = inputs[8]
            if given > 9:
                labels = inputs[9]
            assert given <= 10, "Too many inputs."
        elif isinstance(inputs, (dict, BatchEncoding)):
            # Mapping: a present key replaces the keyword argument, a missing
            # key leaves it as passed.
            input_ids = inputs.get("input_ids")
            attention_mask = inputs.get("attention_mask", attention_mask)
            token_type_ids = inputs.get("token_type_ids", token_type_ids)
            position_ids = inputs.get("position_ids", position_ids)
            head_mask = inputs.get("head_mask", head_mask)
            inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
            output_attentions = inputs.get("output_attentions", output_attentions)
            output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
            return_dict = inputs.get("return_dict", return_dict)
            labels = inputs.get("labels", labels)
            assert len(inputs) <= 10, "Too many inputs."
        else:
            input_ids = inputs

        if return_dict is None:
            return_dict = self.electra.config.return_dict

        # The choice count and sequence length come from whichever of
        # ``input_ids`` / ``inputs_embeds`` was provided.
        reference_dims = shape_list(input_ids if input_ids is not None else inputs_embeds)
        num_choices = reference_dims[1]
        seq_length = reference_dims[2]

        def _fold_choices(tensor):
            # Reshape to (-1, seq_length), merging the leading axes; None passes through.
            return None if tensor is None else tf.reshape(tensor, (-1, seq_length))

        flat_input_ids = _fold_choices(input_ids)
        flat_attention_mask = _fold_choices(attention_mask)
        flat_token_type_ids = _fold_choices(token_type_ids)
        flat_position_ids = _fold_choices(position_ids)
        if inputs_embeds is None:
            flat_inputs_embeds = None
        else:
            # Embeddings keep their trailing axis (index 3) through the reshape.
            flat_inputs_embeds = tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))

        outputs = self.electra(
            flat_input_ids,
            flat_attention_mask,
            flat_token_type_ids,
            flat_position_ids,
            head_mask,
            flat_inputs_embeds,
            output_attentions,
            output_hidden_states,
            return_dict=return_dict,
            training=training,
        )

        summary = self.sequence_summary(outputs[0])
        choice_scores = self.classifier(summary)
        # Regroup the per-row scores into one row of ``num_choices`` per example.
        reshaped_logits = tf.reshape(choice_scores, (-1, num_choices))

        loss = self.compute_loss(labels, reshaped_logits) if labels is not None else None

        if return_dict:
            return TFMultipleChoiceModelOutput(
                loss=loss,
                logits=reshaped_logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        output = (reshaped_logits,) + outputs[1:]
        if loss is None:
            return output
        return (loss,) + output
# ELECTRA encoder followed by dropout and a dense layer with
# ``config.num_labels`` outputs applied to the sequence output.
@add_start_docstrings(
    """Electra model with a token classification head on top.

    Both the discriminator and generator may be loaded into this model.""",
    ELECTRA_START_DOCSTRING,
)
class TFElectraForTokenClassification(TFElectraPreTrainedModel, TFTokenClassificationLoss):
    def __init__(self, config, **kwargs):
        """Create the encoder, the dropout layer and the ``classifier`` projection from ``config``."""
        super().__init__(config, **kwargs)

        self.electra = TFElectraMainLayer(config, name="electra")
        self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
        self.classifier = tf.keras.layers.Dense(
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
        )

    @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="google/electra-small-discriminator",
        output_type=TFTokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        inputs,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
        training=False,
    ):
        r"""
        labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the token classification loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.
        """
        return_dict = return_dict if return_dict is not None else self.electra.config.return_dict

        # ``inputs`` may be a tensor, a positional container or a mapping.
        # Labels are not an encoder input, so they are split off before forwarding.
        if isinstance(inputs, (tuple, list)):
            labels = inputs[9] if len(inputs) > 9 else labels
            if len(inputs) > 9:
                inputs = inputs[:9]
        elif isinstance(inputs, (dict, BatchEncoding)):
            # Read the labels and forward a filtered copy instead of popping,
            # so the caller's mapping still has its "labels" entry afterwards
            # and can be fed to the model again.
            labels = inputs.get("labels", labels)
            inputs = {name: value for name, value in inputs.items() if name != "labels"}

        discriminator_hidden_states = self.electra(
            inputs,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            inputs_embeds,
            output_attentions,
            output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        discriminator_sequence_output = discriminator_hidden_states[0]
        discriminator_sequence_output = self.dropout(discriminator_sequence_output)
        logits = self.classifier(discriminator_sequence_output)

        # The loss is only computed when labels were supplied.
        loss = None if labels is None else self.compute_loss(labels, logits)

        if not return_dict:
            output = (logits,) + discriminator_hidden_states[1:]
            return ((loss,) + output) if loss is not None else output

        return TFTokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=discriminator_hidden_states.hidden_states,
            attentions=discriminator_hidden_states.attentions,
        )
# ELECTRA encoder followed by a dense layer whose output is split into
# span-start and span-end logits.
@add_start_docstrings(
    """Electra Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
    layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    ELECTRA_START_DOCSTRING,
)
class TFElectraForQuestionAnswering(TFElectraPreTrainedModel, TFQuestionAnsweringLoss):
    def __init__(self, config, *inputs, **kwargs):
        """Create the encoder and the ``qa_outputs`` projection from ``config``."""
        super().__init__(config, *inputs, **kwargs)

        self.num_labels = config.num_labels
        self.electra = TFElectraMainLayer(config, name="electra")
        self.qa_outputs = tf.keras.layers.Dense(
            config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs"
        )

    @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="google/electra-small-discriminator",
        output_type=TFQuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def call(
        self,
        inputs,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        start_positions=None,
        end_positions=None,
        training=False,
    ):
        r"""
        start_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`).
            Position outside of the sequence are not taken into account for computing the loss.
        end_positions (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (:obj:`sequence_length`).
            Position outside of the sequence are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.electra.config.return_dict

        # ``inputs`` may be a tensor, a positional container or a mapping.
        # The span labels are not encoder inputs, so they are split off first.
        if isinstance(inputs, (tuple, list)):
            start_positions = inputs[9] if len(inputs) > 9 else start_positions
            end_positions = inputs[10] if len(inputs) > 10 else end_positions
            if len(inputs) > 9:
                inputs = inputs[:9]
        elif isinstance(inputs, (dict, BatchEncoding)):
            # Each label falls back to its OWN keyword argument when the key is
            # missing; using ``start_positions`` as the fallback for the end
            # label would train the end head on start positions.
            start_positions = inputs.get("start_positions", start_positions)
            end_positions = inputs.get("end_positions", end_positions)
            # Forward a filtered copy instead of popping, so the caller's
            # mapping keeps its label entries and can be fed to the model again.
            inputs = {
                name: value
                for name, value in inputs.items()
                if name not in ("start_positions", "end_positions")
            }

        discriminator_hidden_states = self.electra(
            inputs,
            attention_mask,
            token_type_ids,
            position_ids,
            head_mask,
            inputs_embeds,
            output_attentions,
            output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        discriminator_sequence_output = discriminator_hidden_states[0]

        # Split the last axis of the projection into two equal parts (start,
        # end) and squeeze that axis away from each.
        logits = self.qa_outputs(discriminator_sequence_output)
        start_logits, end_logits = tf.split(logits, 2, axis=-1)
        start_logits = tf.squeeze(start_logits, axis=-1)
        end_logits = tf.squeeze(end_logits, axis=-1)

        # The loss needs both span boundaries.
        loss = None
        if start_positions is not None and end_positions is not None:
            labels = {"start_position": start_positions, "end_position": end_positions}
            loss = self.compute_loss(labels, (start_logits, end_logits))

        if not return_dict:
            output = (
                start_logits,
                end_logits,
            ) + discriminator_hidden_states[1:]
            return ((loss,) + output) if loss is not None else output

        return TFQuestionAnsweringModelOutput(
            loss=loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=discriminator_hidden_states.hidden_states,
            attentions=discriminator_hidden_states.attentions,
        )