# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MobileBERT embedding and transformer layers."""
import tensorflow as tf, tf_keras
from official.modeling import tf_utils
from official.nlp.modeling.layers import on_device_embedding
from official.nlp.modeling.layers import position_embedding
@tf_keras.utils.register_keras_serializable(package='Text')
class NoNorm(tf_keras.layers.Layer):
"""Apply element-wise linear transformation to the last dimension."""
def __init__(self, name=None):
super().__init__(name=name)
def build(self, shape):
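    # The layer learns a per-feature affine transform (scale and shift) over
    # the last dimension; unlike layer normalization, it computes no
    # mean/variance statistics.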
    kernel_size = shape[-1]
    self.bias = self.add_weight('beta',
                                shape=[kernel_size],
                                initializer='zeros')
    self.scale = self.add_weight('gamma',
                                 shape=[kernel_size],
                                 initializer='ones')
def call(self, feature):
output = feature * self.scale + self.bias
return output
def _get_norm_layer(normalization_type='no_norm', name=None):
"""Get normlization layer.
Args:
    normalization_type: String. The type of normalization; only
`no_norm` and `layer_norm` are supported.
name: Name for the norm layer.
Returns:
    A normalization layer instance (`NoNorm` or
    `tf_keras.layers.LayerNormalization`).
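
  Example (an illustrative sketch; the name and shapes below are arbitrary):

    norm_layer = _get_norm_layer('layer_norm', name='demo_norm')
    outputs = norm_layer(tf.ones([2, 4, 8]))  # same shape as the input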
"""
if normalization_type == 'no_norm':
layer = NoNorm(name=name)
elif normalization_type == 'layer_norm':
layer = tf_keras.layers.LayerNormalization(
name=name,
axis=-1,
epsilon=1e-12,
dtype=tf.float32)
else:
    raise NotImplementedError('Only "no_norm" and "layer_norm" are supported.')
return layer
@tf_keras.utils.register_keras_serializable(package='Text')
class MobileBertEmbedding(tf_keras.layers.Layer):
"""Performs an embedding lookup for MobileBERT.
  This layer combines word embeddings, token type embeddings, and position
  embeddings.
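
  Example (an illustrative sketch; the sizes below are arbitrary choices):

    embedding_layer = MobileBertEmbedding(
        word_vocab_size=100,
        word_embed_size=16,
        type_vocab_size=2,
        output_embed_size=32)
    input_ids = tf.zeros([1, 8], dtype=tf.int32)
    embeddings = embedding_layer(input_ids)  # shape: (1, 8, 32)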
"""
def __init__(self,
word_vocab_size,
word_embed_size,
type_vocab_size,
output_embed_size,
max_sequence_length=512,
normalization_type='no_norm',
initializer=tf_keras.initializers.TruncatedNormal(stddev=0.02),
dropout_rate=0.1,
**kwargs):
"""Class initialization.
Args:
word_vocab_size: Number of words in the vocabulary.
word_embed_size: Word embedding size.
type_vocab_size: Number of word types.
output_embed_size: Embedding size for the final embedding output.
max_sequence_length: Maximum length of input sequence.
      normalization_type: String. The type of normalization; only
`no_norm` and `layer_norm` are supported.
initializer: The initializer to use for the embedding weights and
linear projection weights.
dropout_rate: Dropout rate.
**kwargs: keyword arguments.
"""
super().__init__(**kwargs)
self.word_vocab_size = word_vocab_size
self.word_embed_size = word_embed_size
self.type_vocab_size = type_vocab_size
self.output_embed_size = output_embed_size
self.max_sequence_length = max_sequence_length
self.normalization_type = normalization_type
self.initializer = tf_keras.initializers.get(initializer)
self.dropout_rate = dropout_rate
self.word_embedding = on_device_embedding.OnDeviceEmbedding(
self.word_vocab_size,
self.word_embed_size,
initializer=tf_utils.clone_initializer(self.initializer),
name='word_embedding')
self.type_embedding = on_device_embedding.OnDeviceEmbedding(
self.type_vocab_size,
self.output_embed_size,
initializer=tf_utils.clone_initializer(self.initializer),
name='type_embedding')
self.pos_embedding = position_embedding.PositionEmbedding(
max_length=max_sequence_length,
initializer=tf_utils.clone_initializer(self.initializer),
name='position_embedding')
self.word_embedding_proj = tf_keras.layers.EinsumDense(
'abc,cd->abd',
output_shape=[None, self.output_embed_size],
kernel_initializer=tf_utils.clone_initializer(self.initializer),
bias_axes='d',
name='embedding_projection')
self.layer_norm = _get_norm_layer(normalization_type, 'embedding_norm')
self.dropout_layer = tf_keras.layers.Dropout(
self.dropout_rate,
name='embedding_dropout')
def get_config(self):
config = {
'word_vocab_size': self.word_vocab_size,
'word_embed_size': self.word_embed_size,
'type_vocab_size': self.type_vocab_size,
'output_embed_size': self.output_embed_size,
'max_sequence_length': self.max_sequence_length,
'normalization_type': self.normalization_type,
'initializer': tf_keras.initializers.serialize(self.initializer),
'dropout_rate': self.dropout_rate
}
base_config = super(MobileBertEmbedding, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, input_ids, token_type_ids=None):
word_embedding_out = self.word_embedding(input_ids)
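    # Build a trigram context for each token: concatenate the next-token and
    # previous-token embeddings (shifted copies, zero-padded at the edges)
    # with the current token embedding along the feature axis, then project
    # the result to `output_embed_size`.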
word_embedding_out = tf.concat(
[tf.pad(word_embedding_out[:, 1:], ((0, 0), (0, 1), (0, 0))),
word_embedding_out,
tf.pad(word_embedding_out[:, :-1], ((0, 0), (1, 0), (0, 0)))],
axis=2)
word_embedding_out = self.word_embedding_proj(word_embedding_out)
pos_embedding_out = self.pos_embedding(word_embedding_out)
embedding_out = word_embedding_out + pos_embedding_out
if token_type_ids is not None:
type_embedding_out = self.type_embedding(token_type_ids)
embedding_out += type_embedding_out
embedding_out = self.layer_norm(embedding_out)
embedding_out = self.dropout_layer(embedding_out)
return embedding_out
@tf_keras.utils.register_keras_serializable(package='Text')
class MobileBertTransformer(tf_keras.layers.Layer):
"""Transformer block for MobileBERT.
An implementation of one layer (block) of Transformer with bottleneck and
  inverted-bottleneck for MobileBERT.
Original paper for MobileBERT:
https://arxiv.org/pdf/2004.02984.pdf
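
  Example (an illustrative sketch; the layer sizes below are arbitrary):

    transformer = MobileBertTransformer(hidden_size=32, num_attention_heads=4)
    features = tf.ones([1, 8, 32])  # last dimension must equal `hidden_size`
    outputs = transformer(features)  # shape: (1, 8, 32)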
"""
def __init__(self,
hidden_size=512,
num_attention_heads=4,
intermediate_size=512,
intermediate_act_fn='relu',
hidden_dropout_prob=0.1,
attention_probs_dropout_prob=0.1,
intra_bottleneck_size=128,
use_bottleneck_attention=False,
key_query_shared_bottleneck=True,
num_feedforward_networks=4,
normalization_type='no_norm',
initializer=tf_keras.initializers.TruncatedNormal(stddev=0.02),
**kwargs):
"""Class initialization.
Args:
hidden_size: Hidden size for the Transformer input and output tensor.
num_attention_heads: Number of attention heads in the Transformer.
intermediate_size: The size of the "intermediate" (a.k.a., feed
forward) layer.
intermediate_act_fn: The non-linear activation function to apply
to the output of the intermediate/feed-forward layer.
hidden_dropout_prob: Dropout probability for the hidden layers.
attention_probs_dropout_prob: Dropout probability of the attention
probabilities.
intra_bottleneck_size: Size of bottleneck.
use_bottleneck_attention: Use attention inputs from the bottleneck
transformation. If true, the following `key_query_shared_bottleneck`
will be ignored.
key_query_shared_bottleneck: Whether to share linear transformation for
keys and queries.
num_feedforward_networks: Number of stacked feed-forward networks.
      normalization_type: The type of normalization; only `no_norm` and
`layer_norm` are supported. `no_norm` represents the element-wise
linear transformation for the student model, as suggested by the
original MobileBERT paper. `layer_norm` is used for the teacher model.
initializer: The initializer to use for the embedding weights and
linear projection weights.
**kwargs: keyword arguments.
Raises:
ValueError: A Tensor shape or parameter is invalid.
"""
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.intermediate_act_fn = intermediate_act_fn
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.intra_bottleneck_size = intra_bottleneck_size
self.use_bottleneck_attention = use_bottleneck_attention
self.key_query_shared_bottleneck = key_query_shared_bottleneck
self.num_feedforward_networks = num_feedforward_networks
self.normalization_type = normalization_type
self.initializer = tf_keras.initializers.get(initializer)
if intra_bottleneck_size % num_attention_heads != 0:
raise ValueError(
(f'The bottleneck size {intra_bottleneck_size} is not a multiple '
f'of the number of attention heads {num_attention_heads}.'))
attention_head_size = int(intra_bottleneck_size / num_attention_heads)
self.block_layers = {}
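    # Sub-layers are grouped in a dict keyed by block name so the forward pass
    # in `call` can look them up stage by stage: input bottleneck, optional
    # shared key/query bottleneck, attention, stacked FFNs, output bottleneck.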
# add input bottleneck
dense_layer_2d = tf_keras.layers.EinsumDense(
'abc,cd->abd',
output_shape=[None, self.intra_bottleneck_size],
bias_axes='d',
kernel_initializer=tf_utils.clone_initializer(self.initializer),
name='bottleneck_input/dense')
layer_norm = _get_norm_layer(self.normalization_type,
name='bottleneck_input/norm')
self.block_layers['bottleneck_input'] = [dense_layer_2d,
layer_norm]
if self.key_query_shared_bottleneck:
dense_layer_2d = tf_keras.layers.EinsumDense(
'abc,cd->abd',
output_shape=[None, self.intra_bottleneck_size],
bias_axes='d',
kernel_initializer=tf_utils.clone_initializer(self.initializer),
name='kq_shared_bottleneck/dense')
layer_norm = _get_norm_layer(self.normalization_type,
name='kq_shared_bottleneck/norm')
self.block_layers['kq_shared_bottleneck'] = [dense_layer_2d,
layer_norm]
# add attention layer
attention_layer = tf_keras.layers.MultiHeadAttention(
num_heads=self.num_attention_heads,
key_dim=attention_head_size,
value_dim=attention_head_size,
dropout=self.attention_probs_dropout_prob,
output_shape=self.intra_bottleneck_size,
kernel_initializer=tf_utils.clone_initializer(self.initializer),
name='attention')
layer_norm = _get_norm_layer(self.normalization_type,
name='attention/norm')
self.block_layers['attention'] = [attention_layer,
layer_norm]
# add stacked feed-forward networks
self.block_layers['ffn'] = []
for ffn_layer_idx in range(self.num_feedforward_networks):
layer_prefix = f'ffn_layer_{ffn_layer_idx}'
layer_name = layer_prefix + '/intermediate_dense'
intermediate_layer = tf_keras.layers.EinsumDense(
'abc,cd->abd',
activation=self.intermediate_act_fn,
output_shape=[None, self.intermediate_size],
bias_axes='d',
kernel_initializer=tf_utils.clone_initializer(self.initializer),
name=layer_name)
layer_name = layer_prefix + '/output_dense'
output_layer = tf_keras.layers.EinsumDense(
'abc,cd->abd',
output_shape=[None, self.intra_bottleneck_size],
bias_axes='d',
kernel_initializer=tf_utils.clone_initializer(self.initializer),
name=layer_name)
layer_name = layer_prefix + '/norm'
layer_norm = _get_norm_layer(self.normalization_type,
name=layer_name)
self.block_layers['ffn'].append([intermediate_layer,
output_layer,
layer_norm])
# add output bottleneck
bottleneck = tf_keras.layers.EinsumDense(
'abc,cd->abd',
output_shape=[None, self.hidden_size],
activation=None,
bias_axes='d',
kernel_initializer=tf_utils.clone_initializer(self.initializer),
name='bottleneck_output/dense')
dropout_layer = tf_keras.layers.Dropout(
self.hidden_dropout_prob,
name='bottleneck_output/dropout')
layer_norm = _get_norm_layer(self.normalization_type,
name='bottleneck_output/norm')
self.block_layers['bottleneck_output'] = [bottleneck,
dropout_layer,
layer_norm]
def get_config(self):
config = {
'hidden_size': self.hidden_size,
'num_attention_heads': self.num_attention_heads,
'intermediate_size': self.intermediate_size,
'intermediate_act_fn': self.intermediate_act_fn,
'hidden_dropout_prob': self.hidden_dropout_prob,
'attention_probs_dropout_prob': self.attention_probs_dropout_prob,
'intra_bottleneck_size': self.intra_bottleneck_size,
'use_bottleneck_attention': self.use_bottleneck_attention,
'key_query_shared_bottleneck': self.key_query_shared_bottleneck,
'num_feedforward_networks': self.num_feedforward_networks,
'normalization_type': self.normalization_type,
'initializer': tf_keras.initializers.serialize(self.initializer),
}
base_config = super(MobileBertTransformer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self,
input_tensor,
attention_mask=None,
return_attention_scores=False):
"""Implementes the forward pass.
Args:
input_tensor: Float tensor of shape
`(batch_size, seq_length, hidden_size)`.
attention_mask: (optional) int32 tensor of shape
`(batch_size, seq_length, seq_length)`, with 1 for positions that can
be attended to and 0 in positions that should not be.
      return_attention_scores: Whether to also return the attention scores.
Returns:
layer_output: Float tensor of shape
`(batch_size, seq_length, hidden_size)`.
      attention_scores: (Optional) Returned only when
        `return_attention_scores` is True.
Raises:
ValueError: A Tensor shape or parameter is invalid.
"""
input_width = input_tensor.shape.as_list()[-1]
if input_width != self.hidden_size:
raise ValueError(
(f'The width of the input tensor {input_width} != '
f'hidden size {self.hidden_size}'))
prev_output = input_tensor
# input bottleneck
dense_layer = self.block_layers['bottleneck_input'][0]
layer_norm = self.block_layers['bottleneck_input'][1]
layer_input = dense_layer(prev_output)
layer_input = layer_norm(layer_input)
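    # Choose the attention inputs: reuse the bottlenecked tensor for query,
    # key, and value; or share a separate bottleneck projection for query/key
    # while keeping the full-width input as value; or bypass the bottleneck
    # for attention entirely.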
if self.use_bottleneck_attention:
key_tensor = layer_input
query_tensor = layer_input
value_tensor = layer_input
elif self.key_query_shared_bottleneck:
dense_layer = self.block_layers['kq_shared_bottleneck'][0]
layer_norm = self.block_layers['kq_shared_bottleneck'][1]
shared_attention_input = dense_layer(prev_output)
shared_attention_input = layer_norm(shared_attention_input)
key_tensor = shared_attention_input
query_tensor = shared_attention_input
value_tensor = prev_output
else:
key_tensor = prev_output
query_tensor = prev_output
value_tensor = prev_output
# attention layer
attention_layer = self.block_layers['attention'][0]
layer_norm = self.block_layers['attention'][1]
attention_output, attention_scores = attention_layer(
query_tensor,
value_tensor,
key_tensor,
attention_mask,
return_attention_scores=True,
)
attention_output = layer_norm(attention_output + layer_input)
# stacked feed-forward networks
layer_input = attention_output
for ffn_idx in range(self.num_feedforward_networks):
intermediate_layer = self.block_layers['ffn'][ffn_idx][0]
output_layer = self.block_layers['ffn'][ffn_idx][1]
layer_norm = self.block_layers['ffn'][ffn_idx][2]
intermediate_output = intermediate_layer(layer_input)
layer_output = output_layer(intermediate_output)
layer_output = layer_norm(layer_output + layer_input)
layer_input = layer_output
# output bottleneck
bottleneck = self.block_layers['bottleneck_output'][0]
dropout_layer = self.block_layers['bottleneck_output'][1]
layer_norm = self.block_layers['bottleneck_output'][2]
layer_output = bottleneck(layer_output)
layer_output = dropout_layer(layer_output)
layer_output = layer_norm(layer_output + prev_output)
if return_attention_scores:
return layer_output, attention_scores
else:
return layer_output
@tf_keras.utils.register_keras_serializable(package='Text')
class MobileBertMaskedLM(tf_keras.layers.Layer):
"""Masked language model network head for BERT modeling.
This layer implements a masked language model based on the provided
transformer based encoder. It assumes that the encoder network being passed
has a "get_embedding_table()" method. Different from canonical BERT's masked
LM layer, when the embedding width is smaller than hidden_size, it adds an
extra output weights in shape [vocab_size, (hidden_size - embedding_width)].
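
  Example (an illustrative sketch; the shapes below are arbitrary):

    embedding_table = tf.Variable(tf.random.normal([100, 16]))  # (vocab, width)
    mlm_layer = MobileBertMaskedLM(embedding_table=embedding_table)
    sequence_data = tf.ones([2, 8, 32])  # hidden size 32 >= embedding width 16
    masked_positions = tf.constant([[1, 3], [2, 5]])
    logits = mlm_layer(sequence_data, masked_positions)  # shape: (2, 2, 100)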
"""
def __init__(self,
embedding_table,
activation=None,
initializer='glorot_uniform',
output='logits',
output_weights_use_proj=False,
**kwargs):
"""Class initialization.
Args:
embedding_table: The embedding table from encoder network.
activation: The activation, if any, for the dense layer.
initializer: The initializer for the dense layer. Defaults to a Glorot
uniform initializer.
output: The output style for this layer. Can be either `logits` or
`predictions`.
      output_weights_use_proj: Whether to use a projection instead of
        concatenating extra output weights. This may reduce MLM task accuracy
        but also reduces the number of model parameters.
**kwargs: keyword arguments.
"""
super().__init__(**kwargs)
self.embedding_table = embedding_table
self.activation = activation
self.initializer = tf_keras.initializers.get(initializer)
if output not in ('predictions', 'logits'):
raise ValueError(
('Unknown `output` value "%s". `output` can be either "logits" or '
'"predictions"') % output)
self._output_type = output
self._output_weights_use_proj = output_weights_use_proj
def build(self, input_shape):
self._vocab_size, embedding_width = self.embedding_table.shape
hidden_size = input_shape[-1]
self.dense = tf_keras.layers.Dense(
hidden_size,
activation=self.activation,
kernel_initializer=tf_utils.clone_initializer(self.initializer),
name='transform/dense')
if hidden_size > embedding_width:
if self._output_weights_use_proj:
self.extra_output_weights = self.add_weight(
'output_weights_proj',
shape=(embedding_width, hidden_size),
initializer=tf_utils.clone_initializer(self.initializer),
trainable=True)
else:
self.extra_output_weights = self.add_weight(
'extra_output_weights',
shape=(self._vocab_size, hidden_size - embedding_width),
initializer=tf_utils.clone_initializer(self.initializer),
trainable=True)
elif hidden_size == embedding_width:
self.extra_output_weights = None
else:
raise ValueError(
'hidden size %d cannot be smaller than embedding width %d.' %
(hidden_size, embedding_width))
self.layer_norm = tf_keras.layers.LayerNormalization(
axis=-1, epsilon=1e-12, name='transform/LayerNorm')
self.bias = self.add_weight(
'output_bias/bias',
shape=(self._vocab_size,),
initializer='zeros',
trainable=True)
super(MobileBertMaskedLM, self).build(input_shape)
def call(self, sequence_data, masked_positions):
masked_lm_input = self._gather_indexes(sequence_data, masked_positions)
lm_data = self.dense(masked_lm_input)
lm_data = self.layer_norm(lm_data)
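    # Map the transformed features back to vocabulary logits. When the hidden
    # size exceeds the embedding width, either project down to the embedding
    # width first or concatenate extra output weights onto the embedding table.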
if self.extra_output_weights is None:
lm_data = tf.matmul(lm_data, self.embedding_table, transpose_b=True)
else:
if self._output_weights_use_proj:
lm_data = tf.matmul(
lm_data, self.extra_output_weights, transpose_b=True)
lm_data = tf.matmul(lm_data, self.embedding_table, transpose_b=True)
else:
lm_data = tf.matmul(
lm_data,
tf.concat([self.embedding_table, self.extra_output_weights],
axis=1),
transpose_b=True)
logits = tf.nn.bias_add(lm_data, self.bias)
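    # Reshape the flat (batch * num_predictions, vocab_size) logits back to
    # (batch, num_predictions, vocab_size), preferring a static length for the
    # predictions axis when it is known.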
masked_positions_length = masked_positions.shape.as_list()[1] or tf.shape(
masked_positions)[1]
logits = tf.reshape(logits,
[-1, masked_positions_length, self._vocab_size])
if self._output_type == 'logits':
return logits
return tf.nn.log_softmax(logits)
def get_config(self):
raise NotImplementedError('MaskedLM cannot be directly serialized because '
'it has variable sharing logic.')
def _gather_indexes(self, sequence_tensor, positions):
"""Gathers the vectors at the specific positions.
Args:
      sequence_tensor: Sequence output of the `BertModel` layer of shape
        `(batch_size, seq_length, num_hidden)`, where `num_hidden` is the
        number of hidden units of the `BertModel` layer.
      positions: Position ids of the tokens to mask for pretraining, with
        shape `(batch_size, num_predictions)`, where `num_predictions` is the
        maximum number of tokens to mask out and predict per sequence.
Returns:
      Tensor of the gathered vectors, of shape
      `(batch_size * num_predictions, num_hidden)`.
"""
sequence_shape = tf.shape(sequence_tensor)
batch_size, seq_length = sequence_shape[0], sequence_shape[1]
width = sequence_tensor.shape.as_list()[2] or sequence_shape[2]
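    # Turn per-example positions into indices into the flattened
    # (batch_size * seq_length, width) tensor by offsetting each row's
    # positions with that row's start index in the flattened batch.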
flat_offsets = tf.reshape(
tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
flat_positions = tf.reshape(positions + flat_offsets, [-1])
flat_sequence_tensor = tf.reshape(sequence_tensor,
[batch_size * seq_length, width])
output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
return output_tensor