# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Adversarial losses for text models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports
from six.moves import xrange
import tensorflow as tf

flags = tf.app.flags
FLAGS = flags.FLAGS

# Adversarial and virtual adversarial training parameters.
flags.DEFINE_float('perturb_norm_length', 5.0,
                   'Norm length of adversarial perturbation to be '
                   'optimized with validation. '
                   '5.0 is optimal on IMDB with virtual adversarial training.')

# Virtual adversarial training parameters
flags.DEFINE_integer('num_power_iteration', 1,
                     'The number of power iterations.')
flags.DEFINE_float('small_constant_for_finite_diff', 1e-1,
                   'Small constant for the finite difference method.')

# Parameters for building the graph
flags.DEFINE_string('adv_training_method', None,
                    'The flag which specifies the training method. '
                    '""      : non-adversarial training (e.g. for running the '
                    '          semi-supervised sequence learning model) '
                    '"rp"    : random perturbation training '
                    '"at"    : adversarial training '
                    '"vat"   : virtual adversarial training '
                    '"atvat" : at + vat ')
flags.DEFINE_float('adv_reg_coeff', 1.0,
                   'Regularization coefficient of adversarial loss.')


def random_perturbation_loss(embedded, length, loss_fn):
  """Adds noise to embeddings and recomputes classification loss."""
  noise = tf.random_normal(shape=tf.shape(embedded))
  perturb = _scale_l2(
      _mask_by_length(noise, length), FLAGS.perturb_norm_length)
  return loss_fn(embedded + perturb)


def adversarial_loss(embedded, loss, loss_fn):
  """Adds gradient to embedding and recomputes classification loss."""
  grad, = tf.gradients(
      loss,
      embedded,
      aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
  grad = tf.stop_gradient(grad)
  perturb = _scale_l2(grad, FLAGS.perturb_norm_length)
  return loss_fn(embedded + perturb)


def virtual_adversarial_loss(logits, embedded, inputs,
                             logits_from_embedding_fn):
  """Virtual adversarial loss.

  Computes virtual adversarial perturbation by finite difference method and
  power iteration, adds it to the embedding, and computes the KL divergence
  between the new logits and the original logits.

  Args:
    logits: 3-D float Tensor, [batch_size, num_timesteps, m], where m=1 if
      num_classes=2, otherwise m=num_classes.
    embedded: 3-D float Tensor, [batch_size, num_timesteps, embedding_dim].
    inputs: VatxtInput.
    logits_from_embedding_fn: callable that takes embeddings and returns
      classifier logits.

  Returns:
    kl: float scalar.
  """
  # Stop gradient of logits. See https://arxiv.org/abs/1507.00677 for details.
  logits = tf.stop_gradient(logits)

  # Only care about the KL divergence on the final timestep.
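  # `eos_weights` is 1.0 only at end-of-sequence positions (see
  # `_kl_divergence_with_logits` below), so the KL term is evaluated only
  # where the classifier actually emits a prediction. With `single_label`,
  # the weights are reduced to one value per example at the last real
  # timestep.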
  weights = inputs.eos_weights
  assert weights is not None
  if FLAGS.single_label:
    indices = tf.stack([tf.range(FLAGS.batch_size), inputs.length - 1], 1)
    weights = tf.expand_dims(tf.gather_nd(inputs.eos_weights, indices), 1)

  # Initialize perturbation with random noise.
  # shape(embedded) = (batch_size, num_timesteps, embedding_dim)
  d = tf.random_normal(shape=tf.shape(embedded))

  # Perform finite difference method and power iteration.
  # See Eq.(8) in the paper http://arxiv.org/pdf/1507.00677.pdf.
  # Adding small noise to the input and taking the gradient with respect to
  # that noise corresponds to one power iteration.
  for _ in xrange(FLAGS.num_power_iteration):
    d = _scale_l2(
        _mask_by_length(d, inputs.length),
        FLAGS.small_constant_for_finite_diff)
    d_logits = logits_from_embedding_fn(embedded + d)
    kl = _kl_divergence_with_logits(logits, d_logits, weights)
    d, = tf.gradients(
        kl,
        d,
        aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
    d = tf.stop_gradient(d)

  perturb = _scale_l2(d, FLAGS.perturb_norm_length)
  vadv_logits = logits_from_embedding_fn(embedded + perturb)
  return _kl_divergence_with_logits(logits, vadv_logits, weights)


def random_perturbation_loss_bidir(embedded, length, loss_fn):
  """Adds noise to embeddings and recomputes classification loss."""
  noise = [tf.random_normal(shape=tf.shape(emb)) for emb in embedded]
  masked = [_mask_by_length(n, length) for n in noise]
  scaled = [_scale_l2(m, FLAGS.perturb_norm_length) for m in masked]
  return loss_fn([e + s for (e, s) in zip(embedded, scaled)])


def adversarial_loss_bidir(embedded, loss, loss_fn):
  """Adds gradient to embeddings and recomputes classification loss."""
  grads = tf.gradients(
      loss,
      embedded,
      aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
  adv_exs = [
      emb + _scale_l2(tf.stop_gradient(g), FLAGS.perturb_norm_length)
      for emb, g in zip(embedded, grads)
  ]
  return loss_fn(adv_exs)


def virtual_adversarial_loss_bidir(logits, embedded, inputs,
                                   logits_from_embedding_fn):
  """Virtual adversarial loss for bidirectional models."""
  logits = tf.stop_gradient(logits)
  f_inputs, _ = inputs
  weights = f_inputs.eos_weights
  if FLAGS.single_label:
    indices = tf.stack([tf.range(FLAGS.batch_size), f_inputs.length - 1], 1)
    weights = tf.expand_dims(tf.gather_nd(f_inputs.eos_weights, indices), 1)
  assert weights is not None

  perturbs = [
      _mask_by_length(tf.random_normal(shape=tf.shape(emb)), f_inputs.length)
      for emb in embedded
  ]
  for _ in xrange(FLAGS.num_power_iteration):
    perturbs = [
        _scale_l2(d, FLAGS.small_constant_for_finite_diff) for d in perturbs
    ]
    d_logits = logits_from_embedding_fn(
        [emb + d for (emb, d) in zip(embedded, perturbs)])
    kl = _kl_divergence_with_logits(logits, d_logits, weights)
    perturbs = tf.gradients(
        kl,
        perturbs,
        aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
    perturbs = [tf.stop_gradient(d) for d in perturbs]

  perturbs = [_scale_l2(d, FLAGS.perturb_norm_length) for d in perturbs]
  vadv_logits = logits_from_embedding_fn(
      [emb + d for (emb, d) in zip(embedded, perturbs)])
  return _kl_divergence_with_logits(logits, vadv_logits, weights)


def _mask_by_length(t, length):
  """Mask t, 3-D [batch, time, dim], by length, 1-D [batch,]."""
  maxlen = t.get_shape().as_list()[1]

  # Subtract 1 from length to prevent the perturbation from going on 'eos'.
  mask = tf.sequence_mask(length - 1, maxlen=maxlen)
  mask = tf.expand_dims(tf.cast(mask, tf.float32), -1)
  # shape(mask) = (batch, num_timesteps, 1)
  return t * mask
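# Worked example of `_mask_by_length` (illustrative; the shapes below are
# chosen only for this example). With batch_size=2, num_timesteps=4 and
# length=[3, 4], the mask keeps timesteps {0, 1} for the first example and
# {0, 1, 2} for the second, zeroing the 'eos' position and any padding:
#
#   t = tf.ones([2, 4, 1])
#   _mask_by_length(t, tf.constant([3, 4]))
#   # -> [[[1.], [1.], [0.], [0.]],
#   #     [[1.], [1.], [1.], [0.]]]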
def _scale_l2(x, norm_length):
  # shape(x) = (batch, num_timesteps, d)
  # Divide x by max(abs(x)) for a numerically stable L2 norm:
  # 2norm(x) = a * 2norm(x/a)
  # Scale over the full sequence, dims (1, 2).
  alpha = tf.reduce_max(tf.abs(x), (1, 2), keep_dims=True) + 1e-12
  l2_norm = alpha * tf.sqrt(
      tf.reduce_sum(tf.pow(x / alpha, 2), (1, 2), keep_dims=True) + 1e-6)
  x_unit = x / l2_norm
  return norm_length * x_unit


def _kl_divergence_with_logits(q_logits, p_logits, weights):
  """Returns weighted KL divergence between distributions q and p.

  Args:
    q_logits: logits for 1st argument of KL divergence, shape
      [batch_size, num_timesteps, num_classes] if num_classes > 2, and
      [batch_size, num_timesteps] if num_classes == 2.
    p_logits: logits for 2nd argument of KL divergence, with the same shape as
      q_logits.
    weights: 2-D float tensor with shape [batch_size, num_timesteps].
      Elements should be 1.0 only at end-of-sequence positions.

  Returns:
    KL: float scalar.
  """
  # For logistic regression
  if FLAGS.num_classes == 2:
    q = tf.nn.sigmoid(q_logits)
    kl = (-tf.nn.sigmoid_cross_entropy_with_logits(logits=q_logits, labels=q) +
          tf.nn.sigmoid_cross_entropy_with_logits(logits=p_logits, labels=q))
    kl = tf.squeeze(kl, 2)

  # For softmax regression
  else:
    q = tf.nn.softmax(q_logits)
    kl = tf.reduce_sum(
        q * (tf.nn.log_softmax(q_logits) - tf.nn.log_softmax(p_logits)), -1)

  num_labels = tf.reduce_sum(weights)
  num_labels = tf.where(tf.equal(num_labels, 0.), 1., num_labels)

  kl.get_shape().assert_has_rank(2)
  weights.get_shape().assert_has_rank(2)

  loss = tf.identity(tf.reduce_sum(weights * kl) / num_labels, name='kl')
  return loss
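

# Illustrative sketch (not part of the original module): how the losses above
# are typically combined with a classification loss, gated by
# `adv_training_method` and weighted by `adv_reg_coeff`. The function name and
# the `loss_fn` argument (a callable mapping embeddings to a scalar
# classification loss) are assumptions made for this example; only the 'rp'
# and 'at' methods are covered, since 'vat' additionally needs logits and
# a VatxtInput.
def _example_combined_loss(embedded, length, cl_loss, loss_fn):
  """Returns `cl_loss` plus the configured adversarial regularizer (sketch)."""
  method = FLAGS.adv_training_method
  if not method:
    # Non-adversarial training: no extra regularization term.
    return cl_loss
  if method == 'rp':
    adv_loss = random_perturbation_loss(embedded, length, loss_fn)
  elif method == 'at':
    adv_loss = adversarial_loss(embedded, cl_loss, loss_fn)
  else:
    raise NotImplementedError(
        'method %r is not covered by this sketch' % method)
  return cl_loss + FLAGS.adv_reg_coeff * adv_loss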