# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines the Transformer model in TF 2.0. | |
Model paper: https://arxiv.org/pdf/1706.03762.pdf | |
Transformer model code source: https://github.com/tensorflow/tensor2tensor | |
""" | |
import tensorflow as tf, tf_keras

from official.legacy.transformer import attention_layer
from official.legacy.transformer import embedding_layer
from official.legacy.transformer import ffn_layer
from official.legacy.transformer import metrics
from official.legacy.transformer import model_utils
from official.legacy.transformer.utils.tokenizer import EOS_ID
from official.nlp.modeling.layers import position_embedding
from official.nlp.modeling.ops import beam_search

# Disable the not-callable lint error, since it claims many objects are not
# callable when they actually are.
# pylint: disable=not-callable


def create_model(params, is_train):
  """Creates transformer model."""
  with tf.name_scope("model"):
    if is_train:
      inputs = tf_keras.layers.Input((None,), dtype="int64", name="inputs")
      targets = tf_keras.layers.Input((None,), dtype="int64", name="targets")
      internal_model = Transformer(params, name="transformer_v2")
      logits = internal_model([inputs, targets], training=is_train)
      vocab_size = params["vocab_size"]
      label_smoothing = params["label_smoothing"]
      if params["enable_metrics_in_training"]:
        logits = metrics.MetricLayer(vocab_size)([logits, targets])
      logits = tf_keras.layers.Lambda(
          lambda x: x, name="logits", dtype=tf.float32)(
              logits)
      model = tf_keras.Model([inputs, targets], logits)
      loss = metrics.transformer_loss(logits, targets, label_smoothing,
                                      vocab_size)
      model.add_loss(loss)
      return model
    else:
      inputs = tf_keras.layers.Input((None,), dtype="int64", name="inputs")
      internal_model = Transformer(params, name="transformer_v2")
      ret = internal_model([inputs], training=is_train)
      outputs, scores = ret["outputs"], ret["scores"]
      return tf_keras.Model(inputs, [outputs, scores])
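
# Illustrative usage of `create_model` (a sketch only; assumes a `params` dict
# populated with the keys referenced in this module, e.g. from the base
# hyperparameter sets shipped alongside this model):
#
#   train_model = create_model(params, is_train=True)   # logits, loss attached
#   infer_model = create_model(params, is_train=False)  # [outputs, scores]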


class Transformer(tf_keras.Model):
  """Transformer model with Keras.

  Implemented as described in: https://arxiv.org/pdf/1706.03762.pdf

  The Transformer model consists of an encoder and decoder. The input is an int
  sequence (or a batch of sequences). The encoder produces a continuous
  representation, and the decoder uses the encoder output to generate
  probabilities for the output sequence.
  """

  def __init__(self, params, name=None):
    """Initialize layers to build Transformer model.

    Args:
      params: hyperparameter object defining layer sizes, dropout values, etc.
      name: name of the model.
    """
    super(Transformer, self).__init__(name=name)
    self.params = params
    self.embedding_softmax_layer = embedding_layer.EmbeddingSharedWeights(
        params["vocab_size"], params["hidden_size"])
    self.encoder_stack = EncoderStack(params)
    self.decoder_stack = DecoderStack(params)
    self.position_embedding = position_embedding.RelativePositionEmbedding(
        hidden_size=self.params["hidden_size"])

  def get_config(self):
    return {
        "params": self.params,
    }

  def call(self, inputs, training):
    """Calculate target logits or inferred target sequences.

    Args:
      inputs: input tensor list of size 1 or 2.
        First item, inputs: int tensor with shape [batch_size, input_length].
        Second item (optional), targets: None or int tensor with shape
          [batch_size, target_length].
      training: boolean, whether in training mode or not.

    Returns:
      If targets is defined, then return logits for each word in the target
      sequence. float tensor with shape [batch_size, target_length, vocab_size]
      If targets is None, then generate output sequence one token at a time.
        returns a dictionary {
          outputs: int tensor with shape [batch_size, decoded_length]
          scores: float tensor with shape [batch_size]}
      Even when float16 is used, the output tensor(s) are always float32.

    Raises:
      NotImplementedError: If padded decoding is used on CPUs/GPUs.
    """
    inputs = inputs if isinstance(inputs, list) else [inputs]
    if len(inputs) == 2:
      inputs, targets = inputs[0], inputs[1]
    else:
      # Decoding path.
      inputs, targets = inputs[0], None
      if self.params["padded_decode"]:
        if not self.params["num_replicas"]:
          raise NotImplementedError(
              "Padded decoding on CPU/GPUs is not supported.")
        decode_batch_size = int(self.params["decode_batch_size"] /
                                self.params["num_replicas"])
        inputs.set_shape([decode_batch_size, self.params["decode_max_length"]])

    # Variance scaling is used here because it seems to work in many problems.
    # Other reasonable initializers may also work just as well.
    with tf.name_scope("Transformer"):
      # Calculate attention bias for encoder self-attention and decoder
      # multi-headed attention layers.
      attention_bias = model_utils.get_padding_bias(inputs)

      # Run the inputs through the encoder layer to map the symbol
      # representations to continuous representations.
      encoder_outputs = self.encode(inputs, attention_bias, training)
      # Generate output sequence if targets is None, or return logits if target
      # sequence is known.
      if targets is None:
        return self.predict(encoder_outputs, attention_bias, training)
      else:
        logits = self.decode(targets, encoder_outputs, attention_bias,
                             training)
        return logits

  def encode(self, inputs, attention_bias, training):
    """Generate continuous representation for inputs.

    Args:
      inputs: int tensor with shape [batch_size, input_length].
      attention_bias: float tensor with shape [batch_size, 1, 1, input_length].
      training: boolean, whether in training mode or not.

    Returns:
      float tensor with shape [batch_size, input_length, hidden_size]
    """
    with tf.name_scope("encode"):
      # Prepare inputs to the layer stack by adding positional encodings and
      # applying dropout.
      embedded_inputs = self.embedding_softmax_layer(inputs)
      embedded_inputs = tf.cast(embedded_inputs, self.params["dtype"])
      inputs_padding = model_utils.get_padding(inputs)
      attention_bias = tf.cast(attention_bias, self.params["dtype"])

      with tf.name_scope("add_pos_encoding"):
        pos_encoding = self.position_embedding(inputs=embedded_inputs)
        pos_encoding = tf.cast(pos_encoding, self.params["dtype"])
        encoder_inputs = embedded_inputs + pos_encoding

      if training:
        encoder_inputs = tf.nn.dropout(
            encoder_inputs, rate=self.params["layer_postprocess_dropout"])

      return self.encoder_stack(
          encoder_inputs, attention_bias, inputs_padding, training=training)

  def decode(self, targets, encoder_outputs, attention_bias, training):
    """Generate logits for each value in the target sequence.

    Args:
      targets: target values for the output sequence. int tensor with shape
        [batch_size, target_length]
      encoder_outputs: continuous representation of input sequence. float
        tensor with shape [batch_size, input_length, hidden_size]
      attention_bias: float tensor with shape [batch_size, 1, 1, input_length]
      training: boolean, whether in training mode or not.

    Returns:
      float32 tensor with shape [batch_size, target_length, vocab_size]
    """
    with tf.name_scope("decode"):
      # Prepare inputs to decoder layers by shifting targets, adding positional
      # encoding and applying dropout.
      with tf.name_scope("shift_targets"):
        # Shift targets to the right, and remove the last element
        targets = tf.pad(targets, [[0, 0], [1, 0]])[:, :-1]
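        # For example, targets [[t1, t2, t3]] become [[0, t1, t2]], so the
        # decoder predicts token i from tokens 0..i-1.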
      decoder_inputs = self.embedding_softmax_layer(targets)
      decoder_inputs = tf.cast(decoder_inputs, self.params["dtype"])
      attention_bias = tf.cast(attention_bias, self.params["dtype"])
      with tf.name_scope("add_pos_encoding"):
        length = tf.shape(decoder_inputs)[1]
        pos_encoding = self.position_embedding(decoder_inputs)
        pos_encoding = tf.cast(pos_encoding, self.params["dtype"])
        decoder_inputs += pos_encoding
      if training:
        decoder_inputs = tf.nn.dropout(
            decoder_inputs, rate=self.params["layer_postprocess_dropout"])

      # Run the decoder stack with a causal self-attention bias, then project
      # its outputs back to vocabulary logits with the shared embedding.
      decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
          length, dtype=self.params["dtype"])
      outputs = self.decoder_stack(
          decoder_inputs,
          encoder_outputs,
          decoder_self_attention_bias,
          attention_bias,
          training=training)
      logits = self.embedding_softmax_layer(outputs, mode="linear")
      logits = tf.cast(logits, tf.float32)
      return logits

  def _get_symbols_to_logits_fn(self, max_decode_length, training):
    """Returns a decoding function that calculates logits of the next tokens."""
    timing_signal = self.position_embedding(
        inputs=None, length=max_decode_length + 1)
    timing_signal = tf.cast(timing_signal, self.params["dtype"])
    decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
        max_decode_length, dtype=self.params["dtype"])

    def symbols_to_logits_fn(ids, i, cache):
      """Generate logits for next potential IDs.

      Args:
        ids: Current decoded sequences. int tensor with shape [batch_size *
          beam_size, i + 1].
        i: Loop index.
        cache: dictionary of values storing the encoder output, encoder-decoder
          attention bias, and previous decoder attention values.

      Returns:
        Tuple of
          (logits with shape [batch_size * beam_size, vocab_size],
           updated cache values)
      """
      # Set decoder input to the last generated IDs
      decoder_input = ids[:, -1:]

      # Preprocess decoder input by getting embeddings and adding timing signal.
      decoder_input = self.embedding_softmax_layer(decoder_input)
      decoder_input += timing_signal[i]
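
      # Slice out row i of the precomputed causal bias so the token generated
      # at this step attends only to positions 0..i (the full-width slice is
      # kept when padded_decode runs with fixed shapes on TPU).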
if self.params["padded_decode"]: | |
bias_shape = decoder_self_attention_bias.shape.as_list() | |
self_attention_bias = tf.slice( | |
decoder_self_attention_bias, [0, 0, i, 0], | |
[bias_shape[0], bias_shape[1], 1, bias_shape[3]]) | |
else: | |
self_attention_bias = decoder_self_attention_bias[:, :, i:i + 1, :i + 1] | |
decoder_outputs = self.decoder_stack( | |
decoder_input, | |
cache.get("encoder_outputs"), | |
self_attention_bias, | |
cache.get("encoder_decoder_attention_bias"), | |
training=training, | |
cache=cache, | |
decode_loop_step=i if self.params["padded_decode"] else None) | |
logits = self.embedding_softmax_layer(decoder_outputs, mode="linear") | |
logits = tf.squeeze(logits, axis=[1]) | |
return logits, cache | |
return symbols_to_logits_fn | |

  def predict(self, encoder_outputs, encoder_decoder_attention_bias, training):
    """Return predicted sequence."""
    encoder_outputs = tf.cast(encoder_outputs, self.params["dtype"])
    if self.params["padded_decode"]:
      batch_size = encoder_outputs.shape.as_list()[0]
      input_length = encoder_outputs.shape.as_list()[1]
    else:
      batch_size = tf.shape(encoder_outputs)[0]
      input_length = tf.shape(encoder_outputs)[1]
    max_decode_length = input_length + self.params["extra_decode_length"]
    encoder_decoder_attention_bias = tf.cast(encoder_decoder_attention_bias,
                                             self.params["dtype"])

    symbols_to_logits_fn = self._get_symbols_to_logits_fn(
        max_decode_length, training)

    # Create initial set of IDs that will be passed into symbols_to_logits_fn.
    initial_ids = tf.zeros([batch_size], dtype=tf.int32)

    # Create cache storing decoder attention values for each layer.
    # pylint: disable=g-complex-comprehension
    init_decode_length = (
        max_decode_length if self.params["padded_decode"] else 0)
    num_heads = self.params["num_heads"]
    dim_per_head = self.params["hidden_size"] // num_heads
    cache = {
        "layer_%d" % layer: {
            "k":
                tf.zeros(
                    [batch_size, init_decode_length, num_heads, dim_per_head],
                    dtype=self.params["dtype"]),
            "v":
                tf.zeros(
                    [batch_size, init_decode_length, num_heads, dim_per_head],
                    dtype=self.params["dtype"])
        } for layer in range(self.params["num_hidden_layers"])
    }
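    # The "k" and "v" entries grow along axis 1 as decoding proceeds (or are
    # written in place at the current loop step when padded_decode is set), so
    # earlier positions are never re-encoded.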
    # pylint: enable=g-complex-comprehension

    # Add encoder output and attention bias to the cache.
    cache["encoder_outputs"] = encoder_outputs
    cache["encoder_decoder_attention_bias"] = encoder_decoder_attention_bias

    # Use beam search to find the top beam_size sequences and scores.
    decoded_ids, scores = beam_search.sequence_beam_search(
        symbols_to_logits_fn=symbols_to_logits_fn,
        initial_ids=initial_ids,
        initial_cache=cache,
        vocab_size=self.params["vocab_size"],
        beam_size=self.params["beam_size"],
        alpha=self.params["alpha"],
        max_decode_length=max_decode_length,
        eos_id=EOS_ID,
        padded_decode=self.params["padded_decode"],
        dtype=self.params["dtype"])

    # Get the top sequence for each batch element.
    top_decoded_ids = decoded_ids[:, 0, 1:]
    top_scores = scores[:, 0]

    return {"outputs": top_decoded_ids, "scores": top_scores}


class PrePostProcessingWrapper(tf_keras.layers.Layer):
  """Wrapper class that applies layer pre-processing and post-processing.

  Pre-processing applies layer normalization to the input; post-processing
  applies dropout followed by a residual connection. The wrapped sublayer
  therefore computes x + dropout(layer(layer_norm(x))).
  """

  def __init__(self, layer, params):
    super(PrePostProcessingWrapper, self).__init__()
    self.layer = layer
    self.params = params
    self.postprocess_dropout = params["layer_postprocess_dropout"]

  def build(self, input_shape):
    # Create normalization layer
    self.layer_norm = tf_keras.layers.LayerNormalization(
        epsilon=1e-6, dtype="float32")
    super(PrePostProcessingWrapper, self).build(input_shape)

  def get_config(self):
    return {
        "params": self.params,
    }

  def call(self, x, *args, **kwargs):
    """Calls wrapped layer with same parameters."""
    # Preprocessing: apply layer normalization
    training = kwargs["training"]

    y = self.layer_norm(x)

    # Get layer output
    y = self.layer(y, *args, **kwargs)

    # Postprocessing: apply dropout and residual connection
    if training:
      y = tf.nn.dropout(y, rate=self.postprocess_dropout)
    return x + y


class EncoderStack(tf_keras.layers.Layer):
  """Transformer encoder stack.

  The encoder stack is made up of N identical layers. Each layer is composed
  of the sublayers:
    1. Self-attention layer
    2. Feedforward network (which is 2 fully-connected layers)
  """

  def __init__(self, params):
    super(EncoderStack, self).__init__()
    self.params = params
    self.layers = []

  def build(self, input_shape):
    """Builds the encoder stack."""
    params = self.params
    for _ in range(params["num_hidden_layers"]):
      # Create sublayers for each layer.
      self_attention_layer = attention_layer.SelfAttention(
          params["hidden_size"], params["num_heads"],
          params["attention_dropout"])
      feed_forward_network = ffn_layer.FeedForwardNetwork(
          params["hidden_size"], params["filter_size"], params["relu_dropout"])

      self.layers.append([
          PrePostProcessingWrapper(self_attention_layer, params),
          PrePostProcessingWrapper(feed_forward_network, params)
      ])

    # Create final layer normalization layer.
    self.output_normalization = tf_keras.layers.LayerNormalization(
        epsilon=1e-6, dtype="float32")
    super(EncoderStack, self).build(input_shape)

  def get_config(self):
    return {
        "params": self.params,
    }

  def call(self, encoder_inputs, attention_bias, inputs_padding, training):
    """Return the output of the encoder layer stacks.

    Args:
      encoder_inputs: tensor with shape [batch_size, input_length, hidden_size]
      attention_bias: bias for the encoder self-attention layer. [batch_size,
        1, 1, input_length]
      inputs_padding: tensor with shape [batch_size, input_length], inputs with
        zero paddings.
      training: boolean, whether in training mode or not.

    Returns:
      Output of encoder layer stack.
      float32 tensor with shape [batch_size, input_length, hidden_size]
    """
    for n, layer in enumerate(self.layers):
      # Run inputs through the sublayers.
      self_attention_layer = layer[0]
      feed_forward_network = layer[1]

      with tf.name_scope("layer_%d" % n):
        with tf.name_scope("self_attention"):
          encoder_inputs = self_attention_layer(
              encoder_inputs, attention_bias, training=training)
        with tf.name_scope("ffn"):
          encoder_inputs = feed_forward_network(
              encoder_inputs, training=training)

    return self.output_normalization(encoder_inputs)


class DecoderStack(tf_keras.layers.Layer):
  """Transformer decoder stack.

  Like the encoder stack, the decoder stack is made up of N identical layers.
  Each layer is composed of the sublayers:
    1. Self-attention layer
    2. Multi-headed attention layer combining encoder outputs with results from
       the previous self-attention layer.
    3. Feedforward network (2 fully-connected layers)
  """

  def __init__(self, params):
    super(DecoderStack, self).__init__()
    self.params = params
    self.layers = []

  def build(self, input_shape):
    """Builds the decoder stack."""
    params = self.params
    for _ in range(params["num_hidden_layers"]):
      self_attention_layer = attention_layer.SelfAttention(
          params["hidden_size"], params["num_heads"],
          params["attention_dropout"])
      enc_dec_attention_layer = attention_layer.Attention(
          params["hidden_size"], params["num_heads"],
          params["attention_dropout"])
      feed_forward_network = ffn_layer.FeedForwardNetwork(
          params["hidden_size"], params["filter_size"], params["relu_dropout"])

      self.layers.append([
          PrePostProcessingWrapper(self_attention_layer, params),
          PrePostProcessingWrapper(enc_dec_attention_layer, params),
          PrePostProcessingWrapper(feed_forward_network, params)
      ])
    self.output_normalization = tf_keras.layers.LayerNormalization(
        epsilon=1e-6, dtype="float32")
    super(DecoderStack, self).build(input_shape)

  def get_config(self):
    return {
        "params": self.params,
    }

  def call(self,
           decoder_inputs,
           encoder_outputs,
           decoder_self_attention_bias,
           attention_bias,
           training,
           cache=None,
           decode_loop_step=None):
    """Return the output of the decoder layer stacks.

    Args:
      decoder_inputs: A tensor with shape [batch_size, target_length,
        hidden_size].
      encoder_outputs: A tensor with shape [batch_size, input_length,
        hidden_size]
      decoder_self_attention_bias: A tensor with shape [1, 1, target_length,
        target_length], the bias for decoder self-attention layer.
      attention_bias: A tensor with shape [batch_size, 1, 1, input_length], the
        bias for encoder-decoder attention layer.
      training: A bool, whether in training mode or not.
      cache: (Used for fast decoding) A nested dictionary storing previous
        decoder self-attention values, keyed and shaped as built in
        Transformer.predict. The items are:
          {layer_n: {"k": A tensor with shape [batch_size, i, num_heads,
                       dim_per_head],
                     "v": A tensor with shape [batch_size, i, num_heads,
                       dim_per_head]},
           ...}
      decode_loop_step: An integer, the step number of the decoding loop. Used
        only for autoregressive inference on TPU.

    Returns:
      Output of decoder layer stack.
      float32 tensor with shape [batch_size, target_length, hidden_size]
    """
    for n, layer in enumerate(self.layers):
      self_attention_layer = layer[0]
      enc_dec_attention_layer = layer[1]
      feed_forward_network = layer[2]

      # Run inputs through the sublayers.
      layer_name = "layer_%d" % n
      layer_cache = cache[layer_name] if cache is not None else None
      with tf.name_scope(layer_name):
        with tf.name_scope("self_attention"):
          decoder_inputs = self_attention_layer(
              decoder_inputs,
              decoder_self_attention_bias,
              training=training,
              cache=layer_cache,
              decode_loop_step=decode_loop_step)
        with tf.name_scope("encdec_attention"):
          decoder_inputs = enc_dec_attention_layer(
              decoder_inputs,
              encoder_outputs,
              attention_bias,
              training=training)
        with tf.name_scope("ffn"):
          decoder_inputs = feed_forward_network(
              decoder_inputs, training=training)

    return self.output_normalization(decoder_inputs)
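

# For reference, the `params` dictionary consumed above must provide at least
# the keys referenced in this module (an illustrative list; values come from
# the hyperparameter sets shipped alongside this model):
#
#   vocab_size, hidden_size, num_hidden_layers, num_heads, filter_size,
#   dtype, layer_postprocess_dropout, attention_dropout, relu_dropout,
#   label_smoothing, enable_metrics_in_training, extra_decode_length,
#   beam_size, alpha, padded_decode, decode_batch_size, decode_max_length,
#   num_replicas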