# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper functions for pre-training. These mainly deal with the gathering and
scattering needed so the generator only makes predictions for the small number
of masked tokens.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.compat.v1 as tf
import configure_pretraining
from model import modeling
from model import tokenization
from pretrain import pretrain_data


def gather_positions(sequence, positions):
  """Gathers the vectors at the specific positions over a minibatch.

  Args:
    sequence: A [batch_size, seq_length] or
        [batch_size, seq_length, depth] tensor of values
    positions: A [batch_size, n_positions] tensor of indices

  Returns: A [batch_size, n_positions] or
      [batch_size, n_positions, depth] tensor of the values at the indices
  """
  shape = modeling.get_shape_list(sequence, expected_rank=[2, 3])
  depth_dimension = (len(shape) == 3)
  if depth_dimension:
    B, L, D = shape
  else:
    B, L = shape
    D = 1
    sequence = tf.expand_dims(sequence, -1)
  position_shift = tf.expand_dims(L * tf.range(B), -1)
  flat_positions = tf.reshape(positions + position_shift, [-1])
  flat_sequence = tf.reshape(sequence, [B * L, D])
  gathered = tf.gather(flat_sequence, flat_positions)
  if depth_dimension:
    return tf.reshape(gathered, [B, -1, D])
  else:
    return tf.reshape(gathered, [B, -1])
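

# A minimal usage sketch (not part of the original ELECTRA code): shows the
# shapes involved when pulling the transformer outputs at the masked positions.
# `_example_gather_positions` is a hypothetical helper added for illustration
# only and is never called by the pre-training code.
def _example_gather_positions():
  sequence = tf.random.uniform([2, 8, 4])         # [batch=2, seq_len=8, depth=4]
  positions = tf.constant([[1, 3], [0, 5]], dtype=tf.int32)  # [batch, n_positions]
  gathered = gather_positions(sequence, positions)
  # gathered has shape [2, 2, 4]: one depth-4 vector per requested position.
  return gathered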


def scatter_update(sequence, updates, positions):
  """Scatter-update a sequence.

  Args:
    sequence: A [batch_size, seq_len] or [batch_size, seq_len, depth] tensor
    updates: A tensor of size batch_size*seq_len(*depth)
    positions: A [batch_size, n_positions] tensor

  Returns: A tuple of two tensors. First is a [batch_size, seq_len] or
      [batch_size, seq_len, depth] tensor of "sequence" with elements at
      "positions" replaced by the values at "updates." Updates to index 0 are
      ignored. If there are duplicated positions the update is only applied
      once. Second is a [batch_size, seq_len] mask tensor of which inputs were
      updated.
  """
  shape = modeling.get_shape_list(sequence, expected_rank=[2, 3])
  depth_dimension = (len(shape) == 3)
  if depth_dimension:
    B, L, D = shape
  else:
    B, L = shape
    D = 1
    sequence = tf.expand_dims(sequence, -1)
  N = modeling.get_shape_list(positions)[1]

  shift = tf.expand_dims(L * tf.range(B), -1)
  flat_positions = tf.reshape(positions + shift, [-1, 1])
  flat_updates = tf.reshape(updates, [-1, D])
  updates = tf.scatter_nd(flat_positions, flat_updates, [B * L, D])
  updates = tf.reshape(updates, [B, L, D])

  flat_updates_mask = tf.ones([B * N], tf.int32)
  updates_mask = tf.scatter_nd(flat_positions, flat_updates_mask, [B * L])
  updates_mask = tf.reshape(updates_mask, [B, L])
  not_first_token = tf.concat([tf.zeros((B, 1), tf.int32),
                               tf.ones((B, L - 1), tf.int32)], -1)
  updates_mask *= not_first_token
  updates_mask_3d = tf.expand_dims(updates_mask, -1)

  # account for duplicate positions
  if sequence.dtype == tf.float32:
    updates_mask_3d = tf.cast(updates_mask_3d, tf.float32)
    updates /= tf.maximum(1.0, updates_mask_3d)
  else:
    assert sequence.dtype == tf.int32
    updates = tf.math.floordiv(updates, tf.maximum(1, updates_mask_3d))
  updates_mask = tf.minimum(updates_mask, 1)
  updates_mask_3d = tf.minimum(updates_mask_3d, 1)

  updated_sequence = (((1 - updates_mask_3d) * sequence) +
                      (updates_mask_3d * updates))
  if not depth_dimension:
    updated_sequence = tf.squeeze(updated_sequence, -1)

  return updated_sequence, updates_mask
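

# A minimal usage sketch (not part of the original ELECTRA code): overwrites
# the tokens at the chosen positions with a fixed id, which is how mask() below
# writes [MASK] tokens into the input. `_example_scatter_update` and the toy
# token ids are hypothetical and included for illustration only.
def _example_scatter_update():
  input_ids = tf.constant([[101, 7592, 2088, 102],
                           [101, 2054, 2003, 102]], dtype=tf.int32)  # [B=2, L=4]
  mask_id = 103
  positions = tf.constant([[1, 2], [2, 0]], dtype=tf.int32)  # position 0 is ignored
  updated, updates_mask = scatter_update(
      input_ids, tf.fill([2, 2], mask_id), positions)
  # updated      -> [[101, 103, 103, 102], [101, 2054, 103, 102]]
  # updates_mask -> marks which positions were overwritten (never position 0).
  return updated, updates_mask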


VOCAB_MAPPING = {}


def get_vocab(config: configure_pretraining.PretrainingConfig):
  """Memoized load of the vocab file."""
  if config.vocab_file not in VOCAB_MAPPING:
    vocab = tokenization.FullTokenizer(
        config.vocab_file, do_lower_case=True).vocab
    VOCAB_MAPPING[config.vocab_file] = vocab
  return VOCAB_MAPPING[config.vocab_file]


def get_candidates_mask(config: configure_pretraining.PretrainingConfig,
                        inputs: pretrain_data.Inputs,
                        disallow_from_mask=None):
  """Returns a mask tensor of positions in the input that can be masked out."""
  vocab = get_vocab(config)
  ignore_ids = [vocab["[SEP]"], vocab["[CLS]"], vocab["[MASK]"]]
  candidates_mask = tf.ones_like(inputs.input_ids, tf.bool)
  for ignore_id in ignore_ids:
    candidates_mask &= tf.not_equal(inputs.input_ids, ignore_id)
  candidates_mask &= tf.cast(inputs.input_mask, tf.bool)
  if disallow_from_mask is not None:
    candidates_mask &= ~disallow_from_mask
  return candidates_mask


def mask(config: configure_pretraining.PretrainingConfig,
         inputs: pretrain_data.Inputs, mask_prob, proposal_distribution=1.0,
         disallow_from_mask=None, already_masked=None):
  """Implementation of dynamic masking. The optional arguments aren't needed
  for BERT/ELECTRA and are from early experiments in "strategically" masking
  out tokens instead of uniformly at random.

  Args:
    config: configure_pretraining.PretrainingConfig
    inputs: pretrain_data.Inputs containing input input_ids/input_mask
    mask_prob: percent of tokens to mask
    proposal_distribution: for non-uniform masking can be a [B, L] tensor
        of scores for masking each position.
    disallow_from_mask: a boolean tensor of [B, L] of positions that should
        not be masked out
    already_masked: a boolean tensor of [B, N] of already masked-out tokens
        for multiple rounds of masking

  Returns: a pretrain_data.Inputs with masking added
  """
  # Get the batch size, sequence length, and max masked-out tokens
  N = config.max_predictions_per_seq
  B, L = modeling.get_shape_list(inputs.input_ids)

  # Find indices where masking out a token is allowed
  vocab = get_vocab(config)
  candidates_mask = get_candidates_mask(config, inputs, disallow_from_mask)

  # Set the number of tokens to mask out per example
  num_tokens = tf.cast(tf.reduce_sum(inputs.input_mask, -1), tf.float32)
  num_to_predict = tf.maximum(1, tf.minimum(
      N, tf.cast(tf.round(num_tokens * mask_prob), tf.int32)))
  masked_lm_weights = tf.cast(tf.sequence_mask(num_to_predict, N), tf.float32)
  if already_masked is not None:
    masked_lm_weights *= (1 - already_masked)

  # Get a probability of masking each position in the sequence
  candidate_mask_float = tf.cast(candidates_mask, tf.float32)
  sample_prob = (proposal_distribution * candidate_mask_float)
  sample_prob /= tf.reduce_sum(sample_prob, axis=-1, keepdims=True)

  # Sample the positions to mask out
  sample_prob = tf.stop_gradient(sample_prob)
  sample_logits = tf.log(sample_prob)
  masked_lm_positions = tf.random.categorical(
      sample_logits, N, dtype=tf.int32)
  masked_lm_positions *= tf.cast(masked_lm_weights, tf.int32)

  # Get the ids of the masked-out tokens
  shift = tf.expand_dims(L * tf.range(B), -1)
  flat_positions = tf.reshape(masked_lm_positions + shift, [-1, 1])
  masked_lm_ids = tf.gather_nd(tf.reshape(inputs.input_ids, [-1]),
                               flat_positions)
  masked_lm_ids = tf.reshape(masked_lm_ids, [B, -1])
  masked_lm_ids *= tf.cast(masked_lm_weights, tf.int32)

  # Update the input ids
  replace_with_mask_positions = masked_lm_positions * tf.cast(
      tf.less(tf.random.uniform([B, N]), 0.85), tf.int32)
  inputs_ids, _ = scatter_update(
      inputs.input_ids, tf.fill([B, N], vocab["[MASK]"]),
      replace_with_mask_positions)

  return pretrain_data.get_updated_inputs(
      inputs,
      input_ids=tf.stop_gradient(inputs_ids),
      masked_lm_positions=masked_lm_positions,
      masked_lm_ids=masked_lm_ids,
      masked_lm_weights=masked_lm_weights
  )
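

# A minimal standalone sketch (not part of the original ELECTRA code) of the
# position-sampling step inside mask() above: candidate positions get uniform
# probability, everything else gets probability zero, and tf.random.categorical
# draws positions per example from the resulting log-probabilities (with
# replacement; duplicates are later handled by scatter_update). The tensors
# below are toy values invented for illustration.
def _example_sample_masked_positions(n_to_sample=2):
  # 1 = position may be masked, 0 = special token / padding
  candidates_mask = tf.constant([[0, 1, 1, 1, 0, 0]], tf.float32)  # [B=1, L=6]
  sample_prob = candidates_mask / tf.reduce_sum(
      candidates_mask, axis=-1, keepdims=True)
  sample_logits = tf.log(sample_prob)  # -inf for disallowed positions
  # [B, n_to_sample] indices drawn only from the allowed positions
  return tf.random.categorical(sample_logits, n_to_sample, dtype=tf.int32)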


def unmask(inputs: pretrain_data.Inputs):
  """Reverses mask() by writing the original token ids back at the masked positions."""
  unmasked_input_ids, _ = scatter_update(
      inputs.input_ids, inputs.masked_lm_ids, inputs.masked_lm_positions)
  return pretrain_data.get_updated_inputs(inputs, input_ids=unmasked_input_ids)


def sample_from_softmax(logits, disallow=None):
  """Samples one-hot ids from softmax(logits) via the Gumbel-max trick."""
  if disallow is not None:
    # Push the disallowed tokens' logits low enough that they are never sampled.
    logits -= 1000.0 * disallow
  uniform_noise = tf.random.uniform(
      modeling.get_shape_list(logits), minval=0, maxval=1)
  gumbel_noise = -tf.log(-tf.log(uniform_noise + 1e-9) + 1e-9)
  return tf.one_hot(tf.argmax(tf.nn.softmax(logits + gumbel_noise), -1,
                              output_type=tf.int32), logits.shape[-1])
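

# A minimal usage sketch (not part of the original ELECTRA code): draws one-hot
# samples from generator logits using sample_from_softmax() above.
# `_example_sample_from_softmax` is a hypothetical helper added purely for
# illustration.
def _example_sample_from_softmax():
  logits = tf.random.normal([2, 3, 30522])  # [batch, n_masked, vocab_size]
  # Optionally pass a one-hot `disallow` tensor of the same shape to forbid
  # sampling particular token ids at particular positions.
  samples = sample_from_softmax(logits)
  # samples is one-hot over the vocabulary: shape [2, 3, 30522].
  sampled_ids = tf.argmax(samples, axis=-1, output_type=tf.int32)
  return sampled_ids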