# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model loss construction.""" | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
# Dependency imports | |
import numpy as np | |
from six.moves import xrange | |
import tensorflow as tf | |
# Useful for REINFORCE baseline. | |
from losses import losses | |
FLAGS = tf.app.flags.FLAGS | |


def create_dis_loss(fake_predictions, real_predictions, targets_present):
  """Compute Discriminator loss across real/fake."""
  missing = tf.cast(targets_present, tf.int32)
  missing = 1 - missing
  missing = tf.cast(missing, tf.bool)

  real_labels = tf.ones([FLAGS.batch_size, FLAGS.sequence_length])
  dis_loss_real = tf.losses.sigmoid_cross_entropy(
      real_labels, real_predictions, weights=missing)
  dis_loss_fake = tf.losses.sigmoid_cross_entropy(
      targets_present, fake_predictions, weights=missing)

  dis_loss = (dis_loss_fake + dis_loss_real) / 2.
  return dis_loss, dis_loss_fake, dis_loss_real
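
# A minimal usage sketch (illustrative only, not part of the original module).
# The real logits come from the Discriminator graph; `targets_present` marks
# which tokens were kept from the real sequence, so `weights=missing` restricts
# both loss terms to the positions the Generator filled in:
#
#   fake_logits = tf.zeros([FLAGS.batch_size, FLAGS.sequence_length])
#   real_logits = tf.zeros([FLAGS.batch_size, FLAGS.sequence_length])
#   targets_present = tf.greater(
#       tf.random_uniform([FLAGS.batch_size, FLAGS.sequence_length]), 0.5)
#   dis_loss, dis_loss_fake, dis_loss_real = create_dis_loss(
#       fake_logits, real_logits, targets_present)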


def create_critic_loss(cumulative_rewards, estimated_values, present):
  """Compute Critic loss in estimating the value function.  This should be an
  estimate only for the missing elements."""
  missing = tf.cast(present, tf.int32)
  missing = 1 - missing
  missing = tf.cast(missing, tf.bool)

  loss = tf.losses.mean_squared_error(
      labels=cumulative_rewards, predictions=estimated_values, weights=missing)
  return loss
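
# Note: the `cumulative_rewards` regression targets are the discounted returns
# computed in calculate_reinforce_objective below; `weights=missing` restricts
# the mean-squared error to the generated (missing) positions.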


def create_masked_cross_entropy_loss(targets, present, logits):
  """Calculate the cross entropy loss matrices for the masked tokens."""
  cross_entropy_losses = losses.cross_entropy_loss_matrix(targets, logits)

  # Zeros matrix.
  zeros_losses = tf.zeros(
      shape=[FLAGS.batch_size, FLAGS.sequence_length], dtype=tf.float32)

  missing_ce_loss = tf.where(present, zeros_losses, cross_entropy_losses)

  return missing_ce_loss
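
# A minimal sketch (illustrative only) of reducing the masked loss matrix to a
# scalar maximum-likelihood loss; the mean-over-missing-tokens reduction here
# is an assumption, not necessarily the one used elsewhere in MaskGAN:
#
#   missing_ce_loss = create_masked_cross_entropy_loss(targets, present, logits)
#   num_missing = tf.reduce_sum(1. - tf.cast(present, tf.float32))
#   mle_loss = tf.reduce_sum(missing_ce_loss) / (num_missing + 1e-7)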


def calculate_reinforce_objective(hparams,
                                  log_probs,
                                  dis_predictions,
                                  present,
                                  estimated_values=None):
  """Calculate the REINFORCE objectives.  The REINFORCE objective should
  only be on the tokens that were missing.  Specifically, the final Generator
  reward should be based on the Discriminator predictions on missing tokens.
  The log probabilities should be only for missing tokens and the baseline
  should be calculated only on the missing tokens.

  For this model, the reward we optimize is the log of the *conditional*
  probability the Discriminator assigns to the distribution.  Specifically, for
  a Discriminator D which outputs probability of real, given the past context,

    r_t = log D(x_t|x_0,x_1,...x_{t-1})

  And the policy for Generator G is the log-probability of taking action x_t
  given the past context.

  Args:
    hparams: MaskGAN hyperparameters.
    log_probs: tf.float32 Tensor of log probabilities of the tokens selected
      by the Generator.  Shape [batch_size, sequence_length].
    dis_predictions: tf.float32 Tensor of the predictions from the
      Discriminator.  Shape [batch_size, sequence_length].
    present: tf.bool Tensor indicating which tokens are present.  Shape
      [batch_size, sequence_length].
    estimated_values: tf.float32 Tensor of estimated state values of tokens.
      Shape [batch_size, sequence_length].

  Returns:
    final_gen_objective: Final REINFORCE objective for the sequence.
    log_probs: tf.float32 Tensor of log probabilities of the selected tokens,
      zeroed at present positions.  Shape [batch_size, sequence_length].
    rewards: tf.float32 Tensor of rewards for sequence of shape [batch_size,
      sequence_length].
    advantages: tf.float32 Tensor of advantages for sequence of shape
      [batch_size, sequence_length].
    baselines: tf.float32 Tensor of baselines for sequence of shape
      [batch_size, sequence_length].
    maintain_averages_op: ExponentialMovingAverage apply average op to
      maintain the baseline (None unless the 'ema' baseline is used).
    critic_loss: Scalar Critic loss (None unless the 'critic' baseline is
      used).
    cumulative_rewards: tf.float32 Tensor of cumulative discounted returns of
      shape [batch_size, sequence_length].
  """
  # Final Generator objective.
  final_gen_objective = 0.
  gamma = hparams.rl_discount_rate

  # Generator rewards are log-probabilities.
  eps = tf.constant(1e-7, tf.float32)
  dis_predictions = tf.nn.sigmoid(dis_predictions)
  rewards = tf.log(dis_predictions + eps)

  # Apply only for missing elements.
  zeros = tf.zeros_like(present, dtype=tf.float32)
  log_probs = tf.where(present, zeros, log_probs)
  rewards = tf.where(present, zeros, rewards)

  # Unstack Tensors into lists.
  rewards_list = tf.unstack(rewards, axis=1)
  log_probs_list = tf.unstack(log_probs, axis=1)
  missing = 1. - tf.cast(present, tf.float32)
  missing_list = tf.unstack(missing, axis=1)

  # Cumulative Discounted Returns.  The true value function V*(s).
  cumulative_rewards = []
  for t in xrange(FLAGS.sequence_length):
    cum_value = tf.zeros(shape=[FLAGS.batch_size])
    for s in xrange(t, FLAGS.sequence_length):
      cum_value += missing_list[s] * np.power(gamma, (s - t)) * rewards_list[s]
    cumulative_rewards.append(cum_value)
  cumulative_rewards = tf.stack(cumulative_rewards, axis=1)
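  # The loops above compute, for every position t,
  #   Q_t = sum_{s >= t} gamma^(s - t) * m_s * r_s,
  # where m_s = 1 for missing (generated) tokens and 0 for given ones.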

  ## REINFORCE with different baselines.
  # We create a separate critic functionality for the Discriminator.  This
  # will need to operate unidirectionally and it may take in the past context.
  if FLAGS.baseline_method == 'critic':

    # Critic loss calculated from the estimated value function \hat{V}(s)
    # versus the true value function V*(s).
    critic_loss = create_critic_loss(cumulative_rewards, estimated_values,
                                     present)

    # Baselines are coming from the critic's estimated state values.
    baselines = tf.unstack(estimated_values, axis=1)

    ## Calculate the Advantages, A(s,a) = Q(s,a) - \hat{V}(s).
    advantages = []
    for t in xrange(FLAGS.sequence_length):
      log_probability = log_probs_list[t]
      cum_advantage = tf.zeros(shape=[FLAGS.batch_size])

      for s in xrange(t, FLAGS.sequence_length):
        cum_advantage += missing_list[s] * np.power(gamma,
                                                    (s - t)) * rewards_list[s]
      cum_advantage -= baselines[t]
      # Clip advantages.
      cum_advantage = tf.clip_by_value(cum_advantage, -FLAGS.advantage_clipping,
                                       FLAGS.advantage_clipping)
      advantages.append(missing_list[t] * cum_advantage)
      final_gen_objective += tf.multiply(
          log_probability, missing_list[t] * tf.stop_gradient(cum_advantage))

    maintain_averages_op = None
    baselines = tf.stack(baselines, axis=1)
    advantages = tf.stack(advantages, axis=1)

  # Split the batch into half.  Use half for MC estimates for REINFORCE.
  # Use the other half to establish a baseline.
  elif FLAGS.baseline_method == 'dis_batch':
    # TODO(liamfedus): Recheck.
    [rewards_half, baseline_half] = tf.split(
        rewards, num_or_size_splits=2, axis=0)
    [log_probs_half, _] = tf.split(log_probs, num_or_size_splits=2, axis=0)
    [reward_present_half, baseline_present_half] = tf.split(
        present, num_or_size_splits=2, axis=0)

    # Unstack to lists.
    baseline_list = tf.unstack(baseline_half, axis=1)
    baseline_missing = 1. - tf.cast(baseline_present_half, tf.float32)
    baseline_missing_list = tf.unstack(baseline_missing, axis=1)

    baselines = []
    for t in xrange(FLAGS.sequence_length):
      # Calculate baseline only for missing tokens.
      num_missing = tf.reduce_sum(baseline_missing_list[t])
      avg_baseline = tf.reduce_sum(
          baseline_missing_list[t] * baseline_list[t], keep_dims=True) / (
              num_missing + eps)
      # Integer division keeps the half-batch tile multiple an int under
      # `from __future__ import division`.
      baseline = tf.tile(avg_baseline, multiples=[FLAGS.batch_size // 2])
      baselines.append(baseline)

    # Unstack to lists.
    rewards_list = tf.unstack(rewards_half, axis=1)
    log_probs_list = tf.unstack(log_probs_half, axis=1)
    reward_missing = 1. - tf.cast(reward_present_half, tf.float32)
    reward_missing_list = tf.unstack(reward_missing, axis=1)

    ## Calculate the Advantages, A(s,a) = Q(s,a) - \hat{V}(s).
    advantages = []
    for t in xrange(FLAGS.sequence_length):
      log_probability = log_probs_list[t]
      cum_advantage = tf.zeros(shape=[FLAGS.batch_size // 2])

      for s in xrange(t, FLAGS.sequence_length):
        cum_advantage += reward_missing_list[s] * np.power(gamma, (s - t)) * (
            rewards_list[s] - baselines[s])
      # Clip advantages.
      cum_advantage = tf.clip_by_value(cum_advantage, -FLAGS.advantage_clipping,
                                       FLAGS.advantage_clipping)
      advantages.append(reward_missing_list[t] * cum_advantage)
      final_gen_objective += tf.multiply(
          log_probability,
          reward_missing_list[t] * tf.stop_gradient(cum_advantage))

    # Cumulative Discounted Returns.  The true value function V*(s).
    cumulative_rewards = []
    for t in xrange(FLAGS.sequence_length):
      cum_value = tf.zeros(shape=[FLAGS.batch_size // 2])
      for s in xrange(t, FLAGS.sequence_length):
        cum_value += reward_missing_list[s] * np.power(gamma, (
            s - t)) * rewards_list[s]
      cumulative_rewards.append(cum_value)
    cumulative_rewards = tf.stack(cumulative_rewards, axis=1)

    rewards = rewards_half
    critic_loss = None
    maintain_averages_op = None
    baselines = tf.stack(baselines, axis=1)
    advantages = tf.stack(advantages, axis=1)

  # Exponential Moving Average baseline.
  elif FLAGS.baseline_method == 'ema':
    # TODO(liamfedus): Recheck.
    # Lists of rewards and log probabilities of the actions taken only for
    # missing tokens.
    ema = tf.train.ExponentialMovingAverage(decay=hparams.baseline_decay)
    maintain_averages_op = ema.apply(rewards_list)

    baselines = []
    for r in rewards_list:
      baselines.append(ema.average(r))

    ## Calculate the Advantages, A(s,a) = Q(s,a) - \hat{V}(s).
    advantages = []
    for t in xrange(FLAGS.sequence_length):
      log_probability = log_probs_list[t]

      # Calculate the forward advantage only on the missing tokens.
      cum_advantage = tf.zeros(shape=[FLAGS.batch_size])
      for s in xrange(t, FLAGS.sequence_length):
        cum_advantage += missing_list[s] * np.power(gamma, (s - t)) * (
            rewards_list[s] - baselines[s])
      # Clip advantages.
      cum_advantage = tf.clip_by_value(cum_advantage, -FLAGS.advantage_clipping,
                                       FLAGS.advantage_clipping)
      advantages.append(missing_list[t] * cum_advantage)
      final_gen_objective += tf.multiply(
          log_probability, missing_list[t] * tf.stop_gradient(cum_advantage))

    critic_loss = None
    baselines = tf.stack(baselines, axis=1)
    advantages = tf.stack(advantages, axis=1)

  elif FLAGS.baseline_method is None:
    num_missing = tf.reduce_sum(missing)
    final_gen_objective += tf.reduce_sum(rewards) / (num_missing + eps)
    baselines = tf.zeros_like(rewards)
    critic_loss = None
    maintain_averages_op = None
    advantages = cumulative_rewards

  else:
    raise NotImplementedError

  return [
      final_gen_objective, log_probs, rewards, advantages, baselines,
      maintain_averages_op, critic_loss, cumulative_rewards
  ]
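
# A minimal sketch (illustrative only; the actual training loop lives elsewhere
# in the MaskGAN code) of consuming the returned objective.  REINFORCE
# maximizes the objective, so the Generator would minimize its negation:
#
#   [final_gen_objective, log_probs, rewards, advantages, baselines,
#    maintain_averages_op, critic_loss, cumulative_rewards] = (
#       calculate_reinforce_objective(hparams, log_probs, dis_predictions,
#                                     targets_present, estimated_values))
#   gen_loss = -final_gen_objective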


def calculate_log_perplexity(logits, targets, present):
  """Calculate the average log perplexity per *missing* token.

  Args:
    logits: tf.float32 Tensor of the logits of shape [batch_size,
      sequence_length, vocab_size].
    targets: tf.int32 Tensor of the sequence target of shape [batch_size,
      sequence_length].
    present: tf.bool Tensor indicating the presence or absence of the token
      of shape [batch_size, sequence_length].

  Returns:
    avg_log_perplexity: Scalar indicating the average log perplexity per
      missing token in the batch.
  """
  # logits = tf.Print(logits, [logits], message='logits:', summarize=50)
  # targets = tf.Print(targets, [targets], message='targets:', summarize=50)
  eps = 1e-12
  logits = tf.reshape(logits, [-1, FLAGS.vocab_size])

  # Only calculate log-perplexity on missing tokens.
  weights = tf.cast(present, tf.float32)
  weights = 1. - weights
  weights = tf.reshape(weights, [-1])
  num_missing = tf.reduce_sum(weights)

  log_perplexity = tf.contrib.legacy_seq2seq.sequence_loss_by_example(
      [logits], [tf.reshape(targets, [-1])], [weights])

  avg_log_perplexity = tf.reduce_sum(log_perplexity) / (num_missing + eps)
  return avg_log_perplexity
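
# A minimal sketch (illustrative only), assuming `logits`, `targets` and
# `targets_present` come from the surrounding evaluation graph: the perplexity
# over missing tokens is the exponential of the returned average.
#
#   avg_log_perplexity = calculate_log_perplexity(logits, targets,
#                                                 targets_present)
#   perplexity = tf.exp(avg_log_perplexity)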