# coding=utf-8
# Copyright 2020 The Facebook AI Research Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TF BART model, ported from the fairseq repo."""
import math
import random
import warnings
from typing import Dict, Optional, Tuple
import numpy as np
import tensorflow as tf
from tensorflow import Tensor
from tensorflow.keras.layers import Dense, Layer, LayerNormalization
from .activations_tf import ACT2FN
from .configuration_bart import BartConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
from .modeling_tf_outputs import TFBaseModelOutput, TFBaseModelOutputWithPast, TFSeq2SeqLMOutput, TFSeq2SeqModelOutput
# Public API
from .modeling_tf_utils import (
DUMMY_INPUTS,
TFPreTrainedModel,
TFSharedEmbeddings,
TFWrappedEmbeddings,
cast_bool_to_primitive,
keras_serializable,
shape_list,
)
from .tokenization_utils_base import BatchEncoding
from .utils import logging
# Name of the configuration class, passed as ``config_class`` to ``replace_return_docstrings`` below.
_CONFIG_FOR_DOC = "BartConfig"
BART_START_DOCSTRING = r"""
This model inherits from :class:`~transformers.TFPreTrainedModel`. Check the superclass documentation for the
generic methods the library implements for all its models (such as downloading or saving, resizing the input
embeddings, pruning heads etc.)
This model is also a `tf.keras.Model <https://www.tensorflow.org/api_docs/python/tf/keras/Model>`__ subclass. Use
it as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matters related to general usage
and behavior.
.. note::
TF 2.0 models accept two formats as inputs:
- having all inputs as keyword arguments (like PyTorch models), or
- having all inputs as a list, tuple or dict in the first positional arguments.
This second option is useful when using :meth:`tf.keras.Model.fit` method which currently requires having all
the tensors in the first argument of the model call function: :obj:`model(inputs)`.
If you choose this second option, there are three possibilities you can use to gather all the input Tensors in
the first positional argument :
- a single Tensor with :obj:`input_ids` only and nothing else: :obj:`model(inputs_ids)`
- a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
:obj:`model([input_ids, attention_mask])` or :obj:`model([input_ids, attention_mask, token_type_ids])`
- a dictionary with one or several input Tensors associated to the input names given in the docstring:
:obj:`model({"input_ids": input_ids, "token_type_ids": token_type_ids})`
Args:
config (:class:`~transformers.BartConfig`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the :meth:`~transformers.TFPreTrainedModel.from_pretrained` method to load the
model weights.
"""
BART_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`tf.Tensor` of shape :obj:`({0})`):
Indices of input sequence tokens in the vocabulary.
Indices can be obtained using :class:`~transformers.BartTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
details.
`What are input IDs? <../glossary.html#input-ids>`__
attention_mask (:obj:`tf.Tensor` of shape :obj:`({0})`, `optional`):
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
decoder_input_ids (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Provide for translation and summarization training. By default, the model will create this tensor by
shifting the input_ids right, following the paper.
decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`):
will be made by default and ignore pad tokens. It is not recommended to set this for most use cases.
encoder_outputs (:obj:`tf.FloatTensor`, `optional`):
Sequence of hidden states at the output of the last layer of the encoder, of shape
:obj:`(batch_size, sequence_length, hidden_size)`. Used in the cross-attention of the decoder.
past_key_values (:obj:`Tuple[Dict[str: tf.Tensor]]` of length :obj:`config.n_layers`)
contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
decoding (see :obj:`past_key_values`). Set to :obj:`False` during training, :obj:`True` during generation.
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
more detail.
return_dict (:obj:`bool`, `optional`):
Whether or not to return a :class:`~transformers.file_utils.TFModelOutput` instead of a plain tuple.
training (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to use the model in training mode (some modules like dropout modules have different
behaviors between training and evaluation).
"""
# Additive bias used to mask attention scores: positions to be ignored receive this value before the softmax.
LARGE_NEGATIVE = -1e8
# Module-level logger named after this module.
logger = logging.get_logger(__name__)
def create_position_ids_from_input_ids(input_ids, padding_idx):
    """
    Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
    are ignored. This is modified from fairseq's `utils.make_positions`.

    Args:
        input_ids: integer tensor of token ids; positions are counted along axis 1.
        padding_idx (int): id of the padding token.

    Returns:
        An int32 tensor with the same shape as ``input_ids``. Padding positions hold ``padding_idx``; the k-th
        non-padding token of each row holds ``padding_idx + k``.
    """
    # The previous implementation called PyTorch tensor methods (`.ne`, `.int`, `.type_as`, `.long`), which do not
    # exist on TensorFlow tensors; the same computation is expressed with TF ops here.
    mask = tf.cast(tf.math.not_equal(input_ids, padding_idx), tf.int32)
    # The cumulative count of real tokens gives 1-based positions; multiplying by the mask zeroes the pads again.
    incremental_indices = tf.cumsum(mask, axis=1) * mask
    return incremental_indices + padding_idx
def causal_attention_mask(nd, ns, dtype):
    """
    Build an additive causal mask of shape `(nd, ns)`.

    Entry `(i, j)` is `LARGE_NEGATIVE` when `i < j - ns + nd` (a position the query may not look at) and 0
    otherwise, so the allowed region is the lower triangle counted from the lower right corner. The mask is built
    from index comparisons rather than a band-part op.
    """
    query_positions = tf.expand_dims(tf.range(nd), axis=1)
    key_positions = tf.range(ns)
    is_future = query_positions < key_positions - ns + nd
    return tf.cast(is_future, dtype) * LARGE_NEGATIVE
def invert_mask(attention_mask: tf.Tensor):
    """Return the boolean negation of a rank-2 mask (1 -> False, 0 -> True, True -> False, False -> True)."""
    tf.debugging.assert_rank(attention_mask, 2)
    as_bool = tf.cast(attention_mask, tf.bool)
    return tf.math.logical_not(as_bool)  # dtype is tf.bool
class TFPretrainedBartModel(TFPreTrainedModel):
    """Base class of the TF BART models: binds the config class and provides dummy inputs and label shifting."""

    config_class = BartConfig
    base_model_prefix = "model"

    @property
    def dummy_inputs(self):
        """Dictionary of small constant tensors built from `DUMMY_INPUTS`."""
        pad_token = 1
        input_ids = tf.cast(tf.constant(DUMMY_INPUTS), tf.int32)
        decoder_input_ids = tf.cast(tf.constant(DUMMY_INPUTS), tf.int32)
        return {
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": tf.math.not_equal(input_ids, pad_token),
            "input_ids": input_ids,
        }

    def _shift_right(self, input_ids):
        """
        Shift `input_ids` one position to the right along the last axis.

        The first column becomes `config.eos_token_id`, every -100 entry is replaced by `config.pad_token_id`, and
        the result is asserted to contain only values >= 0.
        """
        # Should maybe be decoder_start_token_id. Change for torch and TF in one PR
        start_token_id = self.config.eos_token_id
        pad_token_id = self.config.pad_token_id

        rolled = tf.roll(tf.cast(input_ids, tf.int32), 1, axis=-1)
        batch_size = shape_list(rolled)[0]
        first_column = tf.fill((batch_size, 1), start_token_id)
        # The wrapped-around first column of the rolled tensor is dropped in favour of the start tokens.
        shifted = tf.concat([first_column, rolled[:, 1:]], -1)

        # replace possible -100 values in labels by `pad_token_id`
        pad_fill = tf.fill(shape_list(shifted), pad_token_id)
        shifted = tf.where(shifted == -100, pad_fill, shifted)

        # Verify that `labels` has only positive values and -100
        non_negative_check = tf.debugging.assert_greater_equal(shifted, tf.cast(0, tf.int32))
        # Make sure the assertion op is called by wrapping the result in an identity no-op
        with tf.control_dependencies([non_negative_check]):
            shifted = tf.identity(shifted)
        return shifted
# Helper Functions, mostly for making masks
def make_padding_mask(input_ids, padding_idx=1):
    """Return a boolean tensor that is True wherever `input_ids` equals `padding_idx`."""
    return tf.math.equal(input_ids, padding_idx)
# Helper Modules
# Message text for the deprecation of the `past_key_value_states` argument in favour of `past_key_values`.
PAST_KV_DEPRECATION_WARNING = (
    "The `past_key_value_states` argument is deprecated and will be removed in a future "
    "version, use `past_key_values` instead."
)
class TFEncoderLayer(Layer):
    """
    Single encoder block: self-attention followed by a two-layer feed-forward network.

    Each sub-block is followed by dropout and a residual connection. Layer normalization is applied before the
    sub-block when `config.normalize_before` is set, and after the residual sum otherwise.
    """

    def __init__(self, config: BartConfig, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = config.d_model
        self.self_attn = TFAttention(
            self.embed_dim, config.encoder_attention_heads, dropout=config.attention_dropout, name="self_attn"
        )
        self.normalize_before = config.normalize_before
        self.self_attn_layer_norm = LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = Dense(config.encoder_ffn_dim, name="fc1")
        self.fc2 = Dense(self.embed_dim, name="fc2")
        self.final_layer_norm = LayerNormalization(epsilon=1e-5, name="final_layer_norm")

    def call(self, x, encoder_padding_mask, training=False):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (Tensor): mask of shape `(batch, src_len)` forwarded to the attention layer as
                `key_padding_mask`. The attention layer adds it to the attention scores (boolean masks are first
                turned into large negative values where True).
            training (bool): dropout is only applied when this is truthy.

        Returns:
            Tuple of the encoded output of shape `(seq_len, batch, embed_dim)` and the self-attention weights.
        """
        # Self-attention block
        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        x, self_attn_weights = self.self_attn(query=x, key=x, key_padding_mask=encoder_padding_mask)
        assert shape_list(x) == shape_list(
            residual
        ), f"Self attn modified the shape of query {shape_list(residual)} to {shape_list(x)}"
        x = tf.nn.dropout(x, rate=self.dropout if training else 0)
        x = residual + x
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        # Feed-forward block
        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)
        x = self.activation_fn(self.fc1(x))
        # Fixed: this previously read `self.self.activation_dropout`, which raised AttributeError on every call.
        x = tf.nn.dropout(x, rate=self.activation_dropout if training else 0)
        x = self.fc2(x)
        x = tf.nn.dropout(x, rate=self.dropout if training else 0)
        x = residual + x
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        return x, self_attn_weights
class TFBartEncoder(Layer):
    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    :class:`TFEncoderLayer`.

    Args:
        config: BartConfig
        embed_tokens: token embedding layer, called on `input_ids`
    """

    def __init__(self, config: BartConfig, embed_tokens: TFSharedEmbeddings, **kwargs):
        super().__init__(**kwargs)
        self.dropout = config.dropout
        self.layerdrop = config.encoder_layerdrop
        self.output_hidden_states = config.output_hidden_states
        self.output_attentions = config.output_attentions
        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
        self.padding_idx = config.pad_token_id
        self.max_source_positions = config.max_position_embeddings
        self.embed_tokens = embed_tokens
        if config.static_position_embeddings:
            self.embed_positions = TFSinusoidalPositionalEmbedding(
                config.max_position_embeddings,
                config.d_model,
                name="embed_positions",
            )
        else:
            self.embed_positions = TFLearnedPositionalEmbedding(
                config.max_position_embeddings,
                config.d_model,
                self.padding_idx,
                config.extra_pos_embeddings,
                name="embed_positions",
            )
        self.layers = [TFEncoderLayer(config, name=f"layers.{i}") for i in range(config.encoder_layers)]
        # A bare `Layer()` stands in when the embedding normalization is disabled.
        self.layernorm_embedding = (
            LayerNormalization(epsilon=1e-5, name="layernorm_embedding") if config.normalize_embedding else Layer()
        )
        self.layer_norm = LayerNormalization(epsilon=1e-5, name="layer_norm") if config.add_final_layer_norm else None
        self.return_dict = config.return_dict

    def call(
        self,
        input_ids=None,
        attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=None,
        training=False,
    ):
        """
        Args:
            input_ids (Tensor): tokens in the source language of shape `(batch, src_len)`
            attention_mask (Tensor): rank-2 mask indicating which indices are padding tokens (1 = keep, 0 = pad)
            output_attentions (bool): collect the attention weights of every layer
            output_hidden_states (bool): collect the hidden states before each layer and after the last one
            return_dict (bool): return a :class:`TFBaseModelOutput` instead of a tuple; falls back to the config
                value when None
            training (bool): enables dropout and LayerDrop

        Returns:
            :class:`TFBaseModelOutput`, or a tuple of its non-None fields when `return_dict` is falsy:

            - **last_hidden_state** (Tensor): the last encoder layer's output of shape `(batch, src_len, embed_dim)`
            - **hidden_states** (List[Tensor]): all intermediate hidden states of shape `(batch, src_len,
              embed_dim)`. Only populated if *output_hidden_states* is True.
            - **attentions** (Tuple[Tensor]): attention weights for each layer, with `None` for layers skipped by
              LayerDrop. Only populated if *output_attentions* is True.
        """
        output_attentions = output_attentions if output_attentions is not None else self.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.return_dict

        # check attention mask and invert: 1 -> 0.0 (keep), 0 -> LARGE_NEGATIVE (masked), added to attention scores
        if attention_mask is not None:
            assert (
                attention_mask._rank() == 2
            ), f"expected attention_mask._rank() to be a 2D tensor got {attention_mask._rank()}"
            attention_mask = tf.cast(attention_mask, dtype=tf.float32)
            attention_mask = (1.0 - attention_mask) * LARGE_NEGATIVE

        inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
        embed_pos = self.embed_positions(input_ids)
        x = inputs_embeds + embed_pos
        x = self.layernorm_embedding(x)
        x = tf.nn.dropout(x, rate=self.dropout if training else 0)

        # B x T x C -> T x B x C
        x = tf.transpose(x, perm=[1, 0, 2])

        encoder_states = [] if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # encoder layers
        for encoder_layer in self.layers:
            if output_hidden_states:
                encoder_states.append(x)
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = random.uniform(0, 1)
            if training and (dropout_probability < self.layerdrop):  # skip the layer
                attn = None
            else:
                x, attn = encoder_layer(x, attention_mask)

            if output_attentions:
                all_attentions += (attn,)

        # Explicit None check, consistent with the decoder, instead of relying on the truthiness of a Layer.
        if self.layer_norm is not None:
            x = self.layer_norm(x)
        if output_hidden_states:
            encoder_states.append(x)
            # T x B x C -> B x T x C
            encoder_states = [tf.transpose(hidden_state, perm=(1, 0, 2)) for hidden_state in encoder_states]
        x = tf.transpose(x, perm=(1, 0, 2))

        if not return_dict:
            return tuple(v for v in [x, encoder_states, all_attentions] if v is not None)
        return TFBaseModelOutput(last_hidden_state=x, hidden_states=encoder_states, attentions=all_attentions)
class TFDecoderLayer(Layer):
    """
    Single decoder block: self-attention, cross-attention over the encoder output, then a feed-forward network.

    Every sub-block is followed by dropout and a residual sum; layer normalization runs before the sub-block when
    `config.normalize_before` is set and after the residual sum otherwise.
    """

    def __init__(self, config: BartConfig, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = config.d_model
        self.self_attn = TFAttention(
            embed_dim=self.embed_dim,
            num_heads=config.decoder_attention_heads,
            dropout=config.attention_dropout,
            name="self_attn",
        )
        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.normalize_before = config.normalize_before

        self.self_attn_layer_norm = LayerNormalization(epsilon=1e-5, name="self_attn_layer_norm")
        self.encoder_attn = TFAttention(
            self.embed_dim,
            config.decoder_attention_heads,
            dropout=config.attention_dropout,
            encoder_decoder_attention=True,
            name="encoder_attn",
        )
        self.encoder_attn_layer_norm = LayerNormalization(epsilon=1e-5, name="encoder_attn_layer_norm")
        self.fc1 = Dense(config.decoder_ffn_dim, name="fc1")
        self.fc2 = Dense(self.embed_dim, name="fc2")
        self.final_layer_norm = LayerNormalization(epsilon=1e-5, name="final_layer_norm")

    def call(
        self,
        x,
        encoder_hidden_states: tf.Tensor,
        encoder_attn_mask=None,
        layer_state=None,
        causal_mask=None,
        decoder_padding_mask=None,
        training=False,
    ) -> Tuple[tf.Tensor, tf.Tensor, Dict[str, tf.Tensor]]:
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_hidden_states (Tensor): encoder output used as key of the cross-attention
            encoder_attn_mask (Tensor, optional): key padding mask for the cross-attention
            layer_state (dict, optional): attention cache of this layer; it is updated in place by both attention
                modules. A fresh dict is created when None.
            causal_mask (Tensor, optional): additive mask passed to the self-attention as `attn_mask`
            decoder_padding_mask (Tensor, optional): key padding mask for the self-attention
            training (bool): dropout is only applied when this is truthy

        Returns:
            Tuple containing the output of shape `(seq_len, batch, embed_dim)`, the self-attention weights and
            `layer_state` (the cache used for decoding).
        """
        if layer_state is None:
            layer_state = {}
        pre_norm = self.normalize_before
        hidden_rate = self.dropout if training else 0
        activation_rate = self.activation_dropout if training else 0

        # Self-Attention Block
        shortcut = x
        hidden = self.self_attn_layer_norm(x) if pre_norm else x
        # the attention call mutates layer_state
        hidden, self_attn_weights = self.self_attn(
            query=hidden,
            key=hidden,
            layer_state=layer_state,
            attn_mask=causal_mask,
            key_padding_mask=decoder_padding_mask,
        )
        hidden = shortcut + tf.nn.dropout(hidden, rate=hidden_rate)
        if not pre_norm:
            hidden = self.self_attn_layer_norm(hidden)

        # Cross-Attention Block
        shortcut = hidden
        if pre_norm:
            hidden = self.encoder_attn_layer_norm(hidden)
        hidden, _ = self.encoder_attn(
            query=hidden,
            key=encoder_hidden_states,
            key_padding_mask=encoder_attn_mask,
            layer_state=layer_state,  # mutates layer state
        )
        hidden = shortcut + tf.nn.dropout(hidden, rate=hidden_rate)
        if not pre_norm:
            hidden = self.encoder_attn_layer_norm(hidden)

        # Fully Connected
        shortcut = hidden
        if pre_norm:
            hidden = self.final_layer_norm(hidden)
        hidden = self.activation_fn(self.fc1(hidden))
        hidden = tf.nn.dropout(hidden, rate=activation_rate)
        hidden = self.fc2(hidden)
        hidden = shortcut + tf.nn.dropout(hidden, rate=hidden_rate)
        if not pre_norm:
            hidden = self.final_layer_norm(hidden)

        # just self_attn weights for now, following t5, layer_state = cache for decoding
        return hidden, self_attn_weights, layer_state
class TFBartDecoder(Layer):
    """
    Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`TFDecoderLayer`

    Args:
        config: BartConfig
        embed_tokens: output embedding
    """

    def __init__(self, config: BartConfig, embed_tokens, **kwargs):
        super().__init__(**kwargs)
        self.layerdrop = config.decoder_layerdrop
        self.padding_idx = config.pad_token_id
        self.max_target_positions = config.max_position_embeddings
        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
        if config.static_position_embeddings:
            self.embed_positions = TFSinusoidalPositionalEmbedding(
                config.max_position_embeddings,
                config.d_model,
                name="embed_positions",
            )
        else:
            self.embed_positions = TFLearnedPositionalEmbedding(
                config.max_position_embeddings,
                config.d_model,
                self.padding_idx,
                config.extra_pos_embeddings,
                name="embed_positions",
            )
        self.layers = [TFDecoderLayer(config, name=f"layers.{i}") for i in range(config.decoder_layers)]
        # A bare `Layer()` stands in when the embedding normalization is disabled.
        self.layernorm_embedding = (
            LayerNormalization(epsilon=1e-5, name="layernorm_embedding") if config.normalize_embedding else Layer()
        )
        self.layer_norm = LayerNormalization(epsilon=1e-5, name="layer_norm") if config.add_final_layer_norm else None

        self.dropout = config.dropout
        self.output_hidden_states = config.output_hidden_states
        self.output_attentions = config.output_attentions
        self.use_cache = config.use_cache
        self.do_blenderbot_90_layernorm = config.do_blenderbot_90_layernorm
        # Stored here (as the encoder does) because this layer keeps no reference to the config object.
        self.return_dict = config.return_dict

    def call(
        self,
        input_ids,
        encoder_hidden_states,
        encoder_padding_mask,
        decoder_padding_mask,
        decoder_causal_mask,
        decoder_cached_states=None,
        use_cache=False,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=None,
        training=False,
    ):
        """
        Args:
            input_ids (Tensor): decoder token ids of shape `(batch, tgt_len)`
            encoder_hidden_states (Tensor): 3D encoder output, batch-major
            encoder_padding_mask (Tensor, optional): rank-2 encoder attention mask; it is inverted here before
                being used as key padding mask of the cross-attention
            decoder_padding_mask (Tensor, optional): key padding mask for the decoder self-attention
            decoder_causal_mask (Tensor, optional): additive causal mask for the decoder self-attention
            decoder_cached_states (sequence of dict, optional): per-layer attention caches, indexed by layer
            use_cache (bool): only feed the last token and return the updated caches
            output_attentions (bool): collect the self-attention weights of every executed layer
            output_hidden_states (bool): collect the hidden states before each layer and after the last one
            return_dict (bool): return a :class:`TFBaseModelOutputWithPast` instead of a tuple; falls back to the
                config value when None
            training (bool): enables dropout and LayerDrop

        Returns:
            :class:`TFBaseModelOutputWithPast`, or the tuple `(last_hidden_state, past_key_values, hidden_states,
            attentions)` when `return_dict` is falsy. `past_key_values` is `(encoder_hidden_states, caches)` when
            `use_cache` is set and None otherwise.
        """
        output_attentions = output_attentions if output_attentions is not None else self.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
        use_cache = use_cache if use_cache is not None else self.use_cache
        # Fixed: this used to read `self.config.return_dict`, but `self.config` is never set on this layer, so
        # passing `return_dict=None` raised AttributeError.
        return_dict = return_dict if return_dict is not None else self.return_dict
        if use_cache:
            assert not training, "Training + use cache are incompatible"
        # check attention mask and invert
        use_cache = cast_bool_to_primitive(use_cache)
        if encoder_padding_mask is not None:
            encoder_padding_mask = invert_mask(encoder_padding_mask)

        # embed positions
        positions = self.embed_positions(input_ids, use_cache=use_cache)

        if use_cache:
            # Earlier steps are served from the cache, so only the newest token is processed.
            input_ids = input_ids[:, -1:]
            positions = positions[:, -1:]

        x = self.embed_tokens(input_ids) * self.embed_scale
        if self.do_blenderbot_90_layernorm:
            x = self.layernorm_embedding(x) + positions
        else:
            x = self.layernorm_embedding(x + positions)
        x = tf.nn.dropout(x, rate=self.dropout if training else 0)

        # Convert to Bart output format: (BS, seq_len, model_dim) -> (seq_len, BS, model_dim)
        x = tf.transpose(x, perm=(1, 0, 2))
        assert len(shape_list(encoder_hidden_states)) == 3, "encoder_hidden_states must be a 3D tensor"
        encoder_hidden_states = tf.transpose(encoder_hidden_states, perm=(1, 0, 2))

        # decoder layers
        all_hidden_states = ()
        all_self_attns = ()
        next_decoder_cache = []
        for idx, decoder_layer in enumerate(self.layers):
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            if output_hidden_states:
                all_hidden_states += (x,)
            dropout_probability = random.uniform(0, 1)
            if training and (dropout_probability < self.layerdrop):
                continue

            layer_state = decoder_cached_states[idx] if decoder_cached_states is not None else None

            x, layer_self_attn, layer_past = decoder_layer(
                x,
                encoder_hidden_states,
                encoder_attn_mask=encoder_padding_mask,
                decoder_padding_mask=decoder_padding_mask,
                layer_state=layer_state,
                causal_mask=decoder_causal_mask,
            )

            if use_cache:
                next_decoder_cache.append(layer_past.copy())
            if output_attentions:
                all_self_attns += (layer_self_attn,)

        if self.layer_norm is not None:  # same as if config.add_final_layer_norm
            x = self.layer_norm(x)

        # Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
        if output_hidden_states:
            all_hidden_states += (x,)
            # T x B x C -> B x T x C
            all_hidden_states = tuple(tf.transpose(hs, perm=(1, 0, 2)) for hs in all_hidden_states)
        else:
            all_hidden_states = None
        all_self_attns = list(all_self_attns) if output_attentions else None

        x = tf.transpose(x, perm=(1, 0, 2))
        encoder_hidden_states = tf.transpose(encoder_hidden_states, perm=(1, 0, 2))  # could maybe be avoided.

        next_cache = (encoder_hidden_states, next_decoder_cache) if use_cache else None
        if not return_dict:
            return x, next_cache, all_hidden_states, all_self_attns
        return TFBaseModelOutputWithPast(
            last_hidden_state=x,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
def _reorder_buffer(attn_cache, new_order):
for k, input_buffer_k in attn_cache.items():
if input_buffer_k is not None:
attn_cache[k] = tf.gather(input_buffer_k, new_order, axis=0)
return attn_cache
class TFAttention(Layer):
    """
    Multi-headed attention from "Attention Is All You Need".

    Used both as self-attention and, when `encoder_decoder_attention` is True, as cross-attention over the encoder
    output. Keys and values can be cached in a `layer_state` dictionary under the key `"self"` or
    `"encoder_decoder"` respectively.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        bias=True,
        encoder_decoder_attention=False,  # otherwise self_attention
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        # Queries are multiplied by 1 / sqrt(head_dim) right after projection.
        self.scaling = self.head_dim ** -0.5

        self.encoder_decoder_attention = encoder_decoder_attention

        self.k_proj = Dense(embed_dim, use_bias=bias, name="k_proj")
        self.q_proj = Dense(embed_dim, use_bias=bias, name="q_proj")
        self.v_proj = Dense(embed_dim, use_bias=bias, name="v_proj")
        self.out_proj = Dense(embed_dim, use_bias=bias, name="out_proj")

        # Key under which this module stores its keys/values inside `layer_state`.
        self.cache_key = "encoder_decoder" if self.encoder_decoder_attention else "self"

    def _shape(self, tensor: tf.Tensor, dim_0, bsz) -> tf.Tensor:
        """Reshape to `(dim_0, bsz * num_heads, head_dim)` and transpose to `(bsz * num_heads, dim_0, head_dim)`."""
        reshaped_T_B_D = tf.reshape(tensor, (dim_0, bsz * self.num_heads, self.head_dim))
        return tf.transpose(reshaped_T_B_D, perm=(1, 0, 2))

    def call(
        self,
        query: tf.Tensor,
        key: tf.Tensor,
        key_padding_mask: Optional[tf.Tensor] = None,
        layer_state: Optional[Dict[str, tf.Tensor]] = None,
        attn_mask: Optional[Tensor] = None,
        training=False,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """
        Input shape: Time(SeqLen) x Batch x Channel

        Args:
            query: tensor of shape `(tgt_len, bsz, embed_dim)`.
            key: source of keys and values for cross-attention. For self-attention it is not read: keys and values
                are projected from `query`.
            key_padding_mask (optional): mask to exclude keys that are pads, of shape `(batch, src_len)`. A boolean
                mask marks padding with True and is converted to an additive mask; any other dtype is added to the
                attention scores as is.
            layer_state (optional): cache dictionary. When given, previously stored keys/values are read from it
                and the keys/values of this call are written back under `self.cache_key`.
            attn_mask (optional): float32 additive mask, typically used to implement causal attention, where the
                mask prevents the attention from looking forward in time (default: None).
            training: attention dropout is only applied when this is truthy.

        Returns:
            Tuple of the attention output of shape `(tgt_len, bsz, embed_dim)` and the attention weights of shape
            `(bsz, num_heads, tgt_len, src_len)`.
        """
        static_kv = self.encoder_decoder_attention  # value=key=encoder_hidden_states,
        tgt_len, bsz, embed_dim = shape_list(query)
        assert (
            embed_dim == self.embed_dim
        ), f"query must be shaped {(tgt_len, bsz, self.embed_dim)} got {shape_list(query)}"
        # get here for encoder decoder cause of static_kv
        if layer_state is not None:  # get the last k and v for reuse
            saved_state = layer_state.get(self.cache_key, {})
            if "prev_key" in saved_state:
                # previous time steps are cached - no need to recompute key and value if they are static
                if static_kv:
                    key = None
        else:
            # this branch is hit by encoder
            saved_state = None

        # Project query key values using weights q_proj, k_proj, v_proj
        q = self.q_proj(query) * self.scaling
        if static_kv and key is None:  # cross-attention with cache
            k = v = None
        elif static_kv and key is not None:  # cross-attention no prev_key found in cache
            k = self.k_proj(key)
            v = self.v_proj(key)
        else:  # self-attention
            k = self.k_proj(query)
            v = self.v_proj(query)

        # Reshape
        q = self._shape(q, tgt_len, bsz)
        if k is not None:
            k = self._shape(k, -1, bsz)
            v = self._shape(v, -1, bsz)

        # An empty dict is falsy, so nothing is read on the first cached call.
        if saved_state:  # read from cache
            k, v = self._concat_saved_state(k, v, saved_state, static_kv, bsz)

        if layer_state is not None:  # Write to cache every decoder call
            cached_shape = (bsz, self.num_heads, -1, self.head_dim)  # bsz must be first for reorder_cache
            layer_state[self.cache_key] = dict(
                prev_key=tf.reshape(k, cached_shape), prev_value=tf.reshape(v, cached_shape)
            )

        # Compute multi-headed attention
        src_len = shape_list(k)[1]
        attn_weights = tf.matmul(q, k, transpose_b=True)  # shape (bsz * self.num_heads, tgt_len, src_len)

        if attn_mask is not None:
            assert attn_mask.dtype == tf.float32, f"expected dtype tf.float32 got {attn_mask.dtype}"
            # Heads are unfolded so the mask can be added, then folded back.
            attn_weights = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len)) + attn_mask
            attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))

        if key_padding_mask is not None:  # don't attend to padding symbols
            attn_weights: tf.Tensor = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len))
            if key_padding_mask.dtype == tf.bool:
                # True (padding) becomes a large negative bias, False becomes 0.
                key_padding_mask = tf.cast(key_padding_mask, attn_weights.dtype) * -1e9
            # (batch, src_len) -> (batch, 1, 1, src_len)
            extended_mask = tf.expand_dims(tf.expand_dims(key_padding_mask, 1), 2)
            attn_weights = attn_weights + extended_mask
            attn_weights = tf.reshape(attn_weights, (bsz * self.num_heads, tgt_len, src_len))
        attn_weights = tf.nn.softmax(attn_weights, axis=-1)
        attn_probs = tf.nn.dropout(attn_weights, rate=self.dropout if training else 0.0)

        attn_output = tf.matmul(attn_probs, v)  # shape: (bsz * self.num_heads, tgt_len, self.head_dim)
        attn_output = tf.transpose(attn_output, perm=(1, 0, 2))
        attn_output = tf.reshape(attn_output, (tgt_len, bsz, embed_dim))
        attn_output = self.out_proj(attn_output)
        # The returned weights are the softmax output before dropout.
        attn_weights: tf.Tensor = tf.reshape(attn_weights, (bsz, self.num_heads, tgt_len, src_len))
        return attn_output, attn_weights

    def _concat_saved_state(self, k, v, saved_state, static_kv, bsz) -> Tuple[tf.Tensor, tf.Tensor]:
        """
        Combine the cached keys/values with the ones computed in this call.

        With `static_kv` the cached tensors are returned unchanged; otherwise the new `k`/`v` are appended to the
        cached ones along axis 1.
        """
        # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
        prev_key = tf.reshape(saved_state["prev_key"], (bsz * self.num_heads, -1, self.head_dim))
        k = prev_key if static_kv else tf.concat([prev_key, k], axis=1)

        prev_value = tf.reshape(saved_state["prev_value"], (bsz * self.num_heads, -1, self.head_dim))
        v = prev_value if static_kv else tf.concat([prev_value, v], axis=1)
        return k, v
class TFLearnedPositionalEmbedding(TFSharedEmbeddings):
    """
    Learned positional embeddings up to a fixed maximum size.

    The embedding table is enlarged by `offset` rows and every position index is shifted by `offset` before the
    lookup.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, offset, **kwargs):
        # Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
        # and adjust num_embeddings appropriately. Other models dont have this hack
        self.offset = offset
        assert padding_idx is not None, "padding_idx cannot be None"
        super().__init__(num_embeddings + offset, embedding_dim, **kwargs)

    def call(self, input_ids: tf.Tensor, use_cache=False):
        """Input is expected to be of size [bsz x seqlen]."""
        _, seq_len = shape_list(input_ids)[:2]
        positions = (
            # with a cache only the last position is embedded
            tf.fill((1, 1), seq_len - 1)
            if use_cache
            # starts at 0, ends at seq_len - 1
            else tf.range(0, seq_len, delta=1, dtype=tf.int32, name="range")
        )
        return super().call(positions + self.offset)  # super object is not callable for some reason
class TFSinusoidalPositionalEmbedding(tf.keras.layers.Embedding):
    """This module produces sinusoidal positional embeddings of any length."""

    def __init__(self, num_positions, embedding_dim, **kwargs):
        if embedding_dim % 2 != 0:
            raise NotImplementedError(f"odd embedding_dim {embedding_dim} not supported")
        super().__init__(
            num_positions,
            embedding_dim,
            **kwargs,
        )

    def build(self, input_shape):
        """
        Build the embedding weight, then overwrite it with the sinusoidal table. Shared weights logic adapted from
        https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
        """
        super().build(input_shape)  # Instantiates self.weight so it can be loaded
        weight = self._init_weight(self.input_dim, self.output_dim)
        self.set_weights([weight])  # overwrite self.weight to correct value

    @staticmethod
    def _init_weight(n_pos, dim):
        """
        Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in
        the 2nd half of the vector. [dim // 2:]

        Args:
            n_pos (int): number of positions (rows of the table).
            dim (int): embedding dimension (columns of the table).

        Returns:
            float32 tensor of shape `(n_pos, dim)`: sines in columns `[:dim // 2]`, cosines in `[dim // 2:]`.
        """
        position_enc = np.array(
            [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
        )
        # Write into a separate array. Assigning the sine half into `position_enc` itself (as was done before)
        # overwrote angles that the cosine half still had to read, so part of the cosines were computed from
        # sine values instead of the original angles.
        table = np.zeros_like(position_enc)
        table[:, 0 : dim // 2] = np.sin(position_enc[:, 0::2])
        table[:, dim // 2 :] = np.cos(position_enc[:, 1::2])
        # convert to tensor
        table = tf.convert_to_tensor(table, dtype=tf.float32)
        # Return the result of stop_gradient; it was previously computed and discarded.
        return tf.stop_gradient(table)

    def call(self, input_ids, use_cache=False):
        """Input is expected to be of size [bsz x seqlen]."""
        _, seq_len = shape_list(input_ids)[:2]
        if use_cache:
            # with a cache only the last position is embedded
            positions = tf.fill((1, 1), seq_len - 1)
        else:
            # starts at 0, ends at seq_len - 1
            positions = tf.range(0, seq_len, delta=1, dtype=tf.int32, name="range")
        return super().call(positions)
# Public API
[docs]@add_start_docstrings(
"The bare BART Model outputting raw hidden-states without any specific head on top.",
BART_START_DOCSTRING,
)
@keras_serializable
class TFBartModel(TFPretrainedBartModel):
def __init__(self, config: BartConfig, *inputs, **kwargs):
    """Create the shared token embedding and the encoder and decoder layers that both use it."""
    super().__init__(config, *inputs, **kwargs)
    self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, config.pad_token_id, name="model.shared")

    # The scope is entered only to capture its object; nothing is created inside it.
    with tf.compat.v1.variable_scope("model.shared") as shared_abs_scope_name:
        pass

    # Wraps layer to avoid problems with weight restoring and ensuring we're in the correct TF scope.
    embed_tokens = TFWrappedEmbeddings(self.shared, abs_scope_name=shared_abs_scope_name)
    # Copy the size attributes of the shared embedding onto the wrapper.
    embed_tokens.vocab_size = self.shared.vocab_size
    embed_tokens.hidden_size = self.shared.hidden_size

    # Encoder and decoder receive the same wrapped embedding object.
    self.encoder = TFBartEncoder(config, embed_tokens, name="encoder")
    self.decoder = TFBartDecoder(config, embed_tokens, name="decoder")
def _prepare_bart_decoder_inputs(
self,
inputs,
decoder_input_ids=None,
decoder_attn_mask=None,
mask_dtype=None,
):
"""
Prepare masks that ignore padding tokens decoder and a causal lm mask for the decoder if none are provided.
This mimics the default behavior in fairseq. To override it pass in masks.
"""
pad_token_id = self.config.pad_token_id
if decoder_input_ids is None:
decoder_input_ids = self._shift_right(inputs)
bsz, tgt_len = shape_list(decoder_input_ids)[:2]
if decoder_attn_mask is None:
decoder_padding_mask = make_padding_mask(decoder_input_ids, pad_token_id)
else:
decoder_padding_mask = invert_mask(decoder_attn_mask)
causal_lm_mask = causal_attention_mask(tgt_len, tgt_len, mask_dtype)
return decoder_input_ids, decoder_padding_mask, causal_lm_mask
[docs] @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFSeq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
inputs,
attention_mask=None,
decoder_input_ids=None, # BAD DEFAULT LEFT FOR CONSISTENT SIGNATURE
decoder_attention_mask=None,
encoder_outputs: Optional[TFBaseModelOutput] = None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs
):
"""
Returns:
"""
assert "decoder_cached_states" not in kwargs, "Please use past_key_values to cache intermediate outputs"
if isinstance(inputs, (tuple, list)):
assert len(inputs) <= 10, "Too many inputs."
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
decoder_input_ids = inputs[2] if len(inputs) > 2 else decoder_input_ids
decoder_attention_mask = inputs[3] if len(inputs) > 3 else decoder_attention_mask
encoder_outputs = inputs[4] if len(inputs) > 4 else encoder_outputs
past_key_values = inputs[5] if len(inputs) > 5 else past_key_values
use_cache = inputs[6] if len(inputs) > 6 else use_cache
output_attentions = inputs[7] if len(inputs) > 7 else output_attentions
output_hidden_states = inputs[8] if len(inputs) > 8 else output_hidden_states
return_dict = inputs[9] if len(inputs) > 9 else return_dict
elif isinstance(inputs, (dict, BatchEncoding)):
assert len(inputs) <= 10, "Too many inputs."
if "inputs" in inputs:
raise ValueError("Using `inputs` as a keyword argument is deprecated. Please use `input_ids` instead.")
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
decoder_input_ids = inputs.get("decoder_input_ids", decoder_input_ids)
decoder_attention_mask = inputs.get("decoder_attention_mask", decoder_attention_mask)
encoder_outputs = inputs.get("encoder_outputs", encoder_outputs)
past_key_values = inputs.get("past_key_values", past_key_values)
use_cache = inputs.get("use_cache", use_cache)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
else:
input_ids = inputs
use_cache = use_cache if use_cache is not None else self.config.use_cache
if decoder_input_ids is None: # Classification
use_cache = False
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
if not use_cache:
decoder_input_ids, decoder_padding_mask, causal_mask = self._prepare_bart_decoder_inputs(
inputs,
decoder_input_ids=decoder_input_ids,
decoder_attn_mask=decoder_attention_mask,
mask_dtype=self.shared.dtype,
)
else:
decoder_padding_mask, causal_mask = None, None
assert (
isinstance(encoder_outputs, TFBaseModelOutput) or encoder_outputs is None
), f"got unexpected encoder outputs type {type(encoder_outputs)}"
if encoder_outputs is None:
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
training=training,
)
decoder_outputs = self.decoder(
decoder_input_ids,
encoder_outputs.last_hidden_state,
attention_mask,
decoder_padding_mask,
decoder_causal_mask=causal_mask,
decoder_cached_states=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
)
if not return_dict:
# Attention and hidden_states will be [] or None if they aren't needed
return tuple(x for x in decoder_outputs + encoder_outputs.to_tuple() if x is not None)
else:
return TFSeq2SeqModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
def get_input_embeddings(self):
return self.shared
def set_input_embeddings(self, value):
self.shared = value
def get_output_embeddings(self):
return self.shared
[docs]@add_start_docstrings(
"The BART Model with a language modeling head. Can be used for summarization.",
BART_START_DOCSTRING,
)
class TFBartForConditionalGeneration(TFPretrainedBartModel):
base_model_prefix = "model"
authorized_missing_keys = [
r"final_logits_bias",
]
authorized_unexpected_keys = [
r"model.encoder.embed_tokens.weight",
r"model.decoder.embed_tokens.weight",
]
def __init__(self, config: BartConfig, *args, **kwargs):
    """Create the wrapped ``TFBartModel`` (named "model") and the non-trainable final logits bias."""
    super().__init__(config, *args, **kwargs)
    self.model = TFBartModel(config, name="model")
    self.use_cache = config.use_cache
    # final_bias_logits is registered as a buffer in pytorch, so not trainable for the sake of consistency.
    bias_shape = [1, config.vocab_size]
    self.final_logits_bias = self.add_weight(
        name="/final_logits_bias",
        shape=bias_shape,
        initializer="zeros",
        trainable=False,
    )
@add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=TFSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def call(
    self,
    inputs,
    attention_mask=None,
    decoder_input_ids=None,
    decoder_attention_mask=None,
    encoder_outputs: Optional[TFBaseModelOutput] = None,
    past_key_values=None,
    labels=None,
    use_cache=None,
    output_attentions=None,
    output_hidden_states=None,
    return_dict=None,
    training=False,
    **kwargs,
):
    """
    Returns:

    Examples::

        # Mask filling only works for bart-large
        from transformers import BartTokenizer, TFBartForConditionalGeneration
        import tensorflow as tf
        mname = 'facebook/bart-large'
        tokenizer = BartTokenizer.from_pretrained(mname)
        TXT = "My friends are <mask> but they eat too many carbs."
        model = TFBartForConditionalGeneration.from_pretrained(mname)
        batch = tokenizer([TXT], return_tensors='tf')
        logits = model(inputs=batch.input_ids, return_dict=True).logits
        probs = tf.nn.softmax(logits[0])
        # probs[5] is associated with the mask token
    """
    # `inputs` may be a positional tuple/list, a dict/BatchEncoding keyed by argument name, or the bare
    # encoder token ids. Values found inside the container take precedence over the keyword arguments.
    if isinstance(inputs, (tuple, list)):
        assert len(inputs) <= 13, "Too many inputs."
        input_ids = inputs[0]
        attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
        decoder_input_ids = inputs[2] if len(inputs) > 2 else decoder_input_ids
        decoder_attention_mask = inputs[3] if len(inputs) > 3 else decoder_attention_mask
        encoder_outputs = inputs[4] if len(inputs) > 4 else encoder_outputs
        past_key_values = inputs[5] if len(inputs) > 5 else past_key_values
        labels = inputs[6] if len(inputs) > 6 else labels
        use_cache = inputs[7] if len(inputs) > 7 else use_cache
        output_attentions = inputs[8] if len(inputs) > 8 else output_attentions
        output_hidden_states = inputs[9] if len(inputs) > 9 else output_hidden_states
        return_dict = inputs[10] if len(inputs) > 10 else return_dict
    elif isinstance(inputs, (dict, BatchEncoding)):
        assert len(inputs) <= 13, "Too many inputs."
        if "inputs" in inputs:
            warnings.warn("Using `inputs` as a keyword argument is deprecated. Please use `input_ids` instead.")
        if "past_key_value_states" in inputs:
            raise ValueError(PAST_KV_DEPRECATION_WARNING)
        input_ids = inputs.get("input_ids")
        attention_mask = inputs.get("attention_mask", attention_mask)
        decoder_input_ids = inputs.get("decoder_input_ids", decoder_input_ids)
        decoder_attention_mask = inputs.get("decoder_attention_mask", decoder_attention_mask)
        encoder_outputs = inputs.get("encoder_outputs", encoder_outputs)
        past_key_values = inputs.get("past_key_values", past_key_values)
        labels = inputs.get("labels", labels)
        use_cache = inputs.get("use_cache", use_cache)
        output_attentions = inputs.get("output_attentions", output_attentions)
        output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
        # Mirrors slot 10 of the tuple form, so a `return_dict` entry in the dict is honoured too.
        return_dict = inputs.get("return_dict", return_dict)
    else:
        input_ids = inputs

    if "past_key_value_states" in kwargs:
        raise ValueError(PAST_KV_DEPRECATION_WARNING)

    # Compare against None rather than relying on truthiness, so an explicit False is not replaced by the
    # config default.
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    if labels is not None:
        # Caching is switched off whenever labels are supplied.
        use_cache = False

    outputs: TFSeq2SeqModelOutput = self.model(
        input_ids,
        attention_mask=attention_mask,
        decoder_input_ids=decoder_input_ids,
        encoder_outputs=encoder_outputs,
        decoder_attention_mask=decoder_attention_mask,
        past_key_values=past_key_values,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=True,  # TODO(SS): this may need to change to support compilation
    )

    # Project the hidden states back onto the vocabulary with the shared embedding, then add the fixed bias.
    logits = self.model.shared(outputs.last_hidden_state, mode="linear")
    logits = logits + self.final_logits_bias
    loss = None if labels is None else self.compute_loss(labels, logits)

    past = outputs.past_key_values if cast_bool_to_primitive(use_cache, self.config.use_cache) else None

    if return_dict:
        return TFSeq2SeqLMOutput(
            loss=loss,
            logits=logits,
            past_key_values=past,  # index 1 of d outputs
            decoder_hidden_states=outputs.decoder_hidden_states,  # index 2 of d outputs
            decoder_attentions=outputs.decoder_attentions,  # index 3 of d outputs
            # `outputs.last_hidden_state` is the tensor the logits were projected from; the encoder's final
            # state is the dedicated `encoder_last_hidden_state` field, as the tuple branch below also uses.
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,  # index 0 of encoder outputs
            encoder_hidden_states=outputs.encoder_hidden_states,  # 1 of e out
            encoder_attentions=outputs.encoder_attentions,  # 2 of e out
        )
    else:
        # Tuple layout: ([loss], logits, <decoder part>, <encoder part>). When a cache is returned it is the
        # only decoder entry; otherwise the decoder part holds whichever of hidden states / attentions exist.
        if past is not None:
            decoder_outputs = (past,)
        else:
            decoder_outputs = tuple(
                x for x in (outputs.decoder_hidden_states, outputs.decoder_attentions) if x is not None
            )
        enc_out = (outputs.encoder_last_hidden_state, outputs.encoder_hidden_states, outputs.encoder_attentions)
        encoder_outputs = tuple(x for x in enc_out if x is not None)
        output: Tuple = (logits,) + decoder_outputs + encoder_outputs
        return ((loss,) + output) if loss is not None else output
def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache=True, **kwargs) -> Dict:
    """
    Turn the generation-time ``past`` into keyword arguments for ``call``.

    ``past`` must have length 1 or 2: either ``(encoder_hidden_state,)`` or
    ``(encoder_outputs, decoder_cached_states)``. A bare tensor, or a tuple whose first item is a tensor, is
    wrapped into a ``TFBaseModelOutput`` so that ``call`` can skip the encoder.
    """
    assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
    if len(past) != 1:
        assert len(past) == 2
        encoder_outputs, decoder_cached_states = past
        if isinstance(encoder_outputs, tuple):
            first_item = encoder_outputs[0]
            assert isinstance(first_item, tf.Tensor)
            encoder_outputs = TFBaseModelOutput(last_hidden_state=first_item)
        elif isinstance(encoder_outputs, tf.Tensor):
            encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs)
        assert (
            decoder_cached_states
        ), f"decoder cached states must be truthy. got {decoder_cached_states} from the 2nd element of past"
    else:
        # Only the encoder hidden state is available; there is no decoder cache.
        assert isinstance(past[0], tf.Tensor)
        encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0])
        decoder_cached_states = None

    assert isinstance(
        encoder_outputs, TFBaseModelOutput
    ), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}."

    model_kwargs = {}
    model_kwargs["inputs"] = None  # encoder_outputs is defined. input_ids not needed
    model_kwargs["encoder_outputs"] = encoder_outputs
    model_kwargs["past_key_values"] = decoder_cached_states
    model_kwargs["decoder_input_ids"] = decoder_input_ids
    model_kwargs["attention_mask"] = attention_mask
    model_kwargs["use_cache"] = use_cache  # change this to avoid caching (presumably for debugging)
    return model_kwargs
@staticmethod
def _reorder_cache(past, beam_idx):
    """
    Reorder every cached decoder entry according to ``beam_idx``.

    ``past`` is an ``(encoder_out, decoder_cached_states)`` pair. The encoder part is returned untouched; each
    per-layer dict of caches is rebuilt with ``_reorder_buffer`` applied to every value.
    """
    assert len(past) == 2
    encoder_out, decoder_cached_states = past
    # get the correct batch idx from decoder layer's batch dim for cross and self-attn
    reordered_past = [
        {attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()}
        for layer_past in decoder_cached_states
    ]
    return (encoder_out, reordered_past)
def adjust_logits_during_generation(self, logits, cur_len, max_length):
    """
    Force a specific token at the first or last generation step.

    At ``cur_len == 1`` (when ``config.force_bos_token_to_be_generated`` is set) every logit except the BOS one
    is replaced by ``LARGE_NEGATIVE``; at ``cur_len == max_length - 1`` the same is done for EOS. In every other
    case ``logits`` is returned unchanged.
    """
    if cur_len == 1 and self.config.force_bos_token_to_be_generated:
        forced_token_id = self.config.bos_token_id
    elif cur_len == max_length - 1:
        forced_token_id = self.config.eos_token_id
    else:
        return logits
    vocab_range = tf.constant(range(self.config.vocab_size))
    return tf.where(vocab_range != forced_token_id, LARGE_NEGATIVE, logits)
def get_output_embeddings(self):
    """Return the wrapped model's shared embedding layer, which also produces the logits in ``call``."""
    shared_embeddings = self.model.shared
    return shared_embeddings
def get_encoder(self):
    """Return the encoder of the wrapped model."""
    encoder = self.model.encoder
    return encoder
def compute_loss(self, labels, logits):
    """
    CrossEntropyLoss that ignores pad tokens.

    Labels and logits are flattened, positions whose label equals ``config.pad_token_id`` are dropped, and the
    unreduced (``Reduction.NONE``) sparse categorical cross-entropy of the remaining positions is returned.
    """
    last_dim = shape_list(logits)[2]
    flat_logits = tf.reshape(logits, (-1, last_dim))
    flat_labels = tf.reshape(labels, (-1,))
    # True wherever the label is a real token rather than padding.
    keep = tf.not_equal(flat_labels, self.config.pad_token_id)
    kept_logits = tf.boolean_mask(flat_logits, keep)
    kept_labels = tf.boolean_mask(flat_labels, keep)
    cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True,
        reduction=tf.keras.losses.Reduction.NONE,
    )
    return cross_entropy(kept_labels, kept_logits)