Spaces:
Running
Running
# Copyright 2018 The TensorFlow Authors All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
"""Utils for plotting and summarizing. | |
""" | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import matplotlib.gridspec as gridspec | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import scipy | |
import tensorflow as tf | |
import models | |
def summarize_ess(weights, only_last_timestep=False):
  """Writes effective-sample-size summaries for each timestep.

  For each summarized timestep i the weights are softmax-normalized over the
  sample axis (axis 0) and three scalar summaries are written:
    ess/i: 1 / (mean over the batch of sum_s w_s^2).
    ese/i: ess/i divided by the batch size.
    weight_variance/i: sum of squared deviations of the normalized weights
      from their per-batch-element mean, divided by (batch_size - 1).

  Args:
    weights: List of length num_timesteps of Tensors of shape
      [num_samples, batch_size]; presumably log-weights, since they are
      softmax-normalized over samples here.
    only_last_timestep: If True, only the final timestep is summarized.
  """
  num_timesteps = len(weights)
  for i, step_weights in enumerate(weights):
    if only_last_timestep and i < num_timesteps - 1:
      continue
    w = tf.nn.softmax(step_weights, dim=0)
    # Cast to the weights' own dtype: a hard-coded float64 makes the
    # divisions below fail with a dtype mismatch for float32 weights.
    batch_size = tf.cast(tf.shape(w)[1], dtype=w.dtype)
    centered_weights = w - tf.reduce_mean(w, axis=0, keepdims=True)
    # NOTE(review): centering is over samples while the divisor uses the
    # batch size; kept unchanged so existing summaries stay comparable.
    variance = tf.reduce_sum(tf.square(centered_weights)) / (batch_size - 1)
    ess = 1. / tf.reduce_mean(tf.reduce_sum(tf.square(w), axis=0))
    tf.summary.scalar("ess/%d" % i, ess)
    tf.summary.scalar("ese/%d" % i, ess / batch_size)
    tf.summary.scalar("weight_variance/%d" % i, variance)
def summarize_particles(states, weights, observation, model):
  """Adds an image summary plotting particle locations and weights.

  One figure per batch element is rendered inside a `tf.py_func`.
  - Top panel: q(z_0) (blue), the posterior over z_0 given the observation
    (green) and the prior p(z_0) (red).
  - Each timestep: a bar plot of normalized particle weights at the particle
    locations, plus a one-row 2-D histogram of the locations.

  Args:
    states: List of length num_timesteps Tensors of shape
      [batch_size*num_samples, state_size]. They are reshaped to
      [num_timesteps, num_samples, batch_size], so state_size must be 1.
    weights: List of length num_timesteps Tensors of shape
      [num_samples, batch_size]; softmax-normalized over samples here.
    observation: Tensor whose first dimension has (at least) batch_size rows;
      presumably [batch_size*num_samples, state_size]. Only the first
      batch_size rows are plotted.
    model: Object exposing `q.q_zt` and the attributes `p.mixing_coeff`,
      `p.prior_mode_mean`, `p.variance`, `p.bs` and `p.num_timesteps`.
  """
  # A bare `import scipy` does not load the `stats` submodule on older SciPy
  # releases, so import it explicitly (here, so failures surface at graph
  # construction rather than inside the py_func).
  import scipy.stats  # pylint: disable=g-import-not-at-top,redefined-outer-name

  _, batch_size = weights[0].get_shape().as_list()
  # q(z_0) for the first batch_size rows (one row per batch element).
  q0_dist = model.q.q_zt(observation, tf.zeros_like(states[0]), 0)
  q0_loc = q0_dist.loc[0:batch_size, 0]
  q0_scale = q0_dist.scale[0:batch_size, 0]
  # Generative-model parameters needed for the analytic posterior over z_0.
  post = (model.p.mixing_coeff, model.p.prior_mode_mean, model.p.variance,
          tf.reduce_sum(model.p.bs), model.p.num_timesteps)
  # Reshape states and weights to [num_timesteps, num_samples, batch_size].
  states = tf.stack(states)
  weights = tf.stack(weights)
  weights = tf.nn.softmax(weights, dim=1)  # normalize over samples
  states = tf.reshape(states, tf.shape(weights))
  ess = 1. / tf.reduce_sum(tf.square(weights), axis=1)

  def _plot_states(states_batch, weights_batch, observation_batch, ess_batch,
                   q0, post):
    """Renders one RGB image per batch element.

    Args:
      states_batch: [num_timesteps, num_samples, batch_size] particle locations.
      weights_batch: [num_timesteps, num_samples, batch_size] normalized weights.
      observation_batch: Array whose first axis indexes batch elements.
      ess_batch: [num_timesteps, batch_size] effective sample sizes.
      q0: Indexable pair; q0[0][i] / q0[1][i] are the loc / scale of q(z_0).
      post: Sequence (mixing_coeff, prior_mode_mean, variance, sum_of_bs,
        num_timesteps).

    Returns:
      uint8 array of shape [batch_size, height, width, 3].
    """
    num_timesteps, _, batch_size = states_batch.shape
    plots = []
    for i in range(batch_size):
      locs = states_batch[:, :, i]
      wts = weights_batch[:, :, i]
      obs = observation_batch[i]
      example_ess = ess_batch[:, i]
      q0_loc = q0[0][i]
      q0_scale = q0[1][i]
      fig = plt.figure(figsize=(7, (num_timesteps + 1) * 2))
      try:
        # Each timestep gets bar_rows grid rows for the weight bar plot and
        # one row for the location histogram; one extra group at the top
        # holds the prior / posterior / q panel.
        bar_rows = 8
        num_rows = (num_timesteps + 1) * (bar_rows + 1)
        gs = gridspec.GridSpec(num_rows, 1)

        # Shared x-limits covering +/- twice the prior mode mean, the 1%-99%
        # quantiles of q(z_0) and all particle locations (with 15% margin).
        prior_lims = (post[1] * -2, post[1] * 2)
        q_lims = (scipy.stats.norm.ppf(0.01, loc=q0_loc, scale=q0_scale),
                  scipy.stats.norm.ppf(0.99, loc=q0_loc, scale=q0_scale))
        state_width = locs.max() - locs.min()
        state_lims = (locs.min() - state_width * 0.15,
                      locs.max() + state_width * 0.15)
        lims = (min(prior_lims[0], q_lims[0], state_lims[0]),
                max(prior_lims[1], q_lims[1], state_lims[1]))

        # Posterior over z_0 by Bayes' rule, with the observation playing
        # the role of z_n:  p(z_0 | z_n) = p(z_0) p(z_n | z_0) / p(z_n),
        # where the prior is a two-component mixture at +/- pos_mu.
        z0 = np.arange(lims[0], lims[1], 0.1)
        alpha, pos_mu, sigma_sq, B, T = post
        neg_mu = -pos_mu
        scale = np.sqrt((T + 1) * sigma_sq)
        p_zn = (
            alpha * scipy.stats.norm.pdf(obs, loc=pos_mu + B, scale=scale) +
            (1 - alpha) * scipy.stats.norm.pdf(obs, loc=neg_mu + B, scale=scale))
        p_z0 = (
            alpha * scipy.stats.norm.pdf(z0, loc=pos_mu, scale=np.sqrt(sigma_sq))
            + (1 - alpha) * scipy.stats.norm.pdf(
                z0, loc=neg_mu, scale=np.sqrt(sigma_sq)))
        p_zn_given_z0 = scipy.stats.norm.pdf(
            obs, loc=z0 + B, scale=np.sqrt(T * sigma_sq))
        post_z0 = (p_z0 * p_zn_given_z0) / p_zn

        # Top panel: q, posterior and prior densities over z_0.
        q_z0 = scipy.stats.norm.pdf(z0, loc=q0_loc, scale=q0_scale)
        ax = plt.subplot(gs[0:bar_rows, :])
        ax.plot(z0, q_z0, color="blue")
        ax.plot(z0, post_z0, color="green")
        ax.plot(z0, p_z0, color="red")
        ax.legend(("q", "posterior", "prior"), loc="best", prop={"size": 10})
        ax.set_xticks([])
        ax.set_xlim(*lims)

        # Per-timestep panels.
        for t in range(num_timesteps):
          start = (t + 1) * (bar_rows + 1)
          ax1 = plt.subplot(gs[start:start + bar_rows, :])
          ax2 = plt.subplot(gs[start + bar_rows:start + bar_rows + 1, :])
          # Each particle's normalized weight as a thin bar at its location.
          ax1.bar(locs[t, :], wts[t, :], width=0.02, alpha=0.2,
                  edgecolor="none")
          ax1.set_ylabel("t=%d" % t)
          ax1.set_xticks([])
          ax1.grid(True, which="both")
          ax1.set_xlim(*lims)
          ax1.axvline(x=obs, color="red", linestyle="dashed")
          ax1.text(0.1, 0.9, "ESS: %0.2f" % example_ess[t],
                   ha='center', va='center', transform=ax1.transAxes)
          # One-row histogram of particle locations.
          ax2.hist2d(
              locs[t, :], np.zeros_like(locs[t, :]), bins=[50, 1],
              cmap="Greys")
          ax2.grid(False)
          ax2.set_yticks([])
          ax2.set_xlim(*lims)
          if t != num_timesteps - 1:
            ax2.set_xticks([])

        fig.canvas.draw()
        width, height = fig.canvas.get_width_height()
        # np.frombuffer replaces the deprecated binary mode of np.fromstring.
        rgb = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
        plots.append(rgb.reshape(height, width, 3))
      finally:
        # Always release the figure, even if drawing failed.
        plt.close(fig)
    return np.stack(plots)

  plots = tf.py_func(_plot_states,
                     [states, weights, observation, ess, (q0_loc, q0_scale),
                      post],
                     [tf.uint8])[0]
  tf.summary.image("states", plots, 5, collections=["infrequent_summaries"])
def plot_weights(weights, resampled=None):
  """Adds an image summary of the normalized weights from an SMC rollout.

  One stacked-area plot is rendered per batch element. It shows how the
  softmax-normalized weight mass is split across samples at each step. Steps
  where resampling occurred are marked with dashed red vertical lines.

  Args:
    weights: [num_timesteps, num_samples, batch_size] importance weights;
      softmax-normalized over the sample axis before plotting.
    resampled: Optional [num_timesteps] 0/1 indicator of whether resampling
      occurred at each step. Defaults to all zeros (no resampling).
  """
  weights = tf.convert_to_tensor(weights)

  def _make_plots(weights, resampled):
    """Returns a uint8 array [batch_size, height, width, 3] of RGB plots."""
    num_timesteps, _, batch_size = weights.shape
    resample_steps = np.where(resampled > 0)[0]  # same for every example
    plots = []
    for i in range(batch_size):
      fig, axes = plt.subplots(nrows=1, sharex=True, figsize=(8, 4))
      try:
        axes.stackplot(np.arange(num_timesteps), np.transpose(weights[:, :, i]))
        axes.set_title("Weights")
        axes.set_xlabel("Steps")
        axes.set_ylim([0, 1])
        axes.set_xlim([0, num_timesteps - 1])
        for j in resample_steps:
          axes.axvline(x=j, color="red", linestyle="dashed", ymin=0.0, ymax=1.0)
        fig.canvas.draw()
        width, height = fig.canvas.get_width_height()
        # np.frombuffer replaces the deprecated binary mode of np.fromstring.
        data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
        plots.append(data.reshape(height, width, 3))
      finally:
        # Always release the figure, even if drawing failed.
        plt.close(fig)
    return np.stack(plots, axis=0)

  if resampled is None:
    # Use the dynamic length so this also works when the number of timesteps
    # is not known statically.
    resampled = tf.zeros(tf.shape(weights)[:1], dtype=tf.float32)
  plots = tf.py_func(_make_plots,
                     [tf.nn.softmax(weights, dim=1),
                      tf.cast(resampled, tf.float32)], [tf.uint8])[0]
  batch_size = weights.get_shape().as_list()[-1]
  if batch_size is None:
    batch_size = 3  # tf.summary.image's default max_outputs
  tf.summary.image(
      "weights", plots, batch_size, collections=["infrequent_summaries"])
def summarize_weights(weights, num_timesteps, num_samples):
  """Writes per-timestep summaries of the importance weights.

  For each timestep t < num_timesteps this writes the across-sample variance
  averaged over the batch, the mean absolute weight, and a histogram of the
  weights at that step.

  Args:
    weights: [num_timesteps, num_samples, batch_size] weights (Tensor or
      anything convertible to one).
    num_timesteps: Number of leading timesteps to summarize.
    num_samples: Number of samples; the variance uses the (num_samples - 1)
      denominator.
  """
  weights = tf.convert_to_tensor(weights)
  mean = tf.reduce_mean(weights, axis=1, keepdims=True)
  squared_diff = tf.square(weights - mean)
  variances = tf.reduce_sum(squared_diff, axis=1) / (num_samples - 1)
  # Average the per-example variance over the batch.
  variances = tf.reduce_mean(variances, axis=1)
  avg_magnitude = tf.reduce_mean(tf.abs(weights), axis=[1, 2])
  # `range`, not the Python-2-only `xrange`, which raises NameError on Py3.
  for t in range(num_timesteps):
    tf.summary.scalar("weights/variance_%d" % t, variances[t])
    tf.summary.scalar("weights/magnitude_%d" % t, avg_magnitude[t])
    tf.summary.histogram("weights/step_%d" % t, weights[t])
def summarize_learning_signal(rewards, tag):
  """Writes per-event summaries of a learning signal.

  For each row t of `rewards` this writes the mean, the mean absolute value,
  the mean square, and a histogram, all under the prefix `tag`.

  Args:
    rewards: Rank-2 Tensor with a statically known first dimension (one row
      per resampling event); statistics are taken over axis 1.
    tag: Prefix for the summary names.
  """
  num_resampling_events, _ = rewards.get_shape().as_list()
  mean = tf.reduce_mean(rewards, axis=1)
  avg_magnitude = tf.reduce_mean(tf.abs(rewards), axis=1)
  reward_square = tf.reduce_mean(tf.square(rewards), axis=1)
  # `range`, not the Python-2-only `xrange`, which raises NameError on Py3.
  for t in range(num_resampling_events):
    tf.summary.scalar("%s/mean_%d" % (tag, t), mean[t])
    tf.summary.scalar("%s/magnitude_%d" % (tag, t), avg_magnitude[t])
    tf.summary.scalar("%s/squared_%d" % (tag, t), reward_square[t])
    tf.summary.histogram("%s/step_%d" % (tag, t), rewards[t])
def summarize_qs(model, observation, states):
  """Summarizes how closely the proposal q matches the model's posterior.

  Always calls `model.q.summarize_weights()`. If `model.p` has a callable
  `posterior`, then for every timestep t it builds:
    p = model.p.posterior(observation, prev_state, t)
    q = model.q.q_zt(observation, prev_state, t)
  where prev_state is zeros for t = 0 and states[t - 1] otherwise. It then
  writes the mean KL(p || q), plus the absolute and relative errors of q's
  mean and variance against p's.

  Args:
    model: Model whose `q` provides `summarize_weights` and `q_zt`, and whose
      `p` may provide `posterior`.
    observation: Observation passed to both distributions.
    states: Sequence (list or tuple) of per-timestep state Tensors.
  """
  model.q.summarize_weights()
  posterior_fn = getattr(model.p, "posterior", None)
  if not callable(posterior_fn):
    return
  # list() so tuples work too (tuple + list would raise TypeError).
  states = list(states)
  if not states:
    return
  prev_states = [tf.zeros_like(states[0])] + states[:-1]
  for t, prev_state in enumerate(prev_states):
    p = posterior_fn(observation, prev_state, t)
    q = model.q.q_zt(observation, prev_state, t)
    # Already a scalar; no second reduction needed.
    kl = tf.reduce_mean(tf.contrib.distributions.kl_divergence(p, q))
    tf.summary.scalar("kl_q/%d" % t, kl)
    mean_diff = q.loc - p.loc
    mean_abs_err = tf.abs(mean_diff)
    # NOTE(review): infinite wherever p.loc == 0.
    mean_rel_err = tf.abs(mean_diff / p.loc)
    tf.summary.scalar("q_mean_convergence/absolute_error_%d" % t,
                      tf.reduce_mean(mean_abs_err))
    tf.summary.scalar("q_mean_convergence/relative_error_%d" % t,
                      tf.reduce_mean(mean_rel_err))
    # Compare variances (squared scales).
    sigma_diff = tf.square(q.scale) - tf.square(p.scale)
    sigma_abs_err = tf.abs(sigma_diff)
    sigma_rel_err = tf.abs(sigma_diff / tf.square(p.scale))
    tf.summary.scalar("q_variance_convergence/absolute_error_%d" % t,
                      tf.reduce_mean(sigma_abs_err))
    tf.summary.scalar("q_variance_convergence/relative_error_%d" % t,
                      tf.reduce_mean(sigma_rel_err))
def summarize_rs(model, states):
  """Summarizes how closely the learned r matches the model's lookahead.

  Calls `model.r.summarize_weights()`. Then, for each state at step t, it
  compares target = model.p.lookahead(state, t) with
  approx = model.r.r_xn(state, t). It writes the mean KL(target || approx),
  plus the absolute and relative errors of approx's mean and variance
  against target's.

  Args:
    model: Model whose `r` provides `summarize_weights` and `r_xn`, and whose
      `p` provides `lookahead`.
    states: Iterable of per-timestep state Tensors.
  """
  model.r.summarize_weights()
  for step, state in enumerate(states):
    target = model.p.lookahead(state, step)
    approx = model.r.r_xn(state, step)
    step_kl = tf.reduce_mean(
        tf.contrib.distributions.kl_divergence(target, approx))
    tf.summary.scalar("kl_r/%d" % step, tf.reduce_mean(step_kl))

    # Mean errors; the relative error divides by the target mean.
    loc_gap = target.loc - approx.loc
    loc_abs_gap = tf.abs(loc_gap)
    loc_rel_gap = tf.abs(loc_gap / target.loc)
    tf.summary.scalar("r_mean_convergence/absolute_error_%d" % step,
                      tf.reduce_mean(loc_abs_gap))
    tf.summary.scalar("r_mean_convergence/relative_error_%d" % step,
                      tf.reduce_mean(loc_rel_gap))

    # Variance errors, comparing squared scales.
    var_gap = tf.square(approx.scale) - tf.square(target.scale)
    var_abs_gap = tf.abs(var_gap)
    var_rel_gap = tf.abs(var_gap / tf.square(target.scale))
    tf.summary.scalar("r_variance_convergence/absolute_error_%d" % step,
                      tf.reduce_mean(var_abs_gap))
    tf.summary.scalar("r_variance_convergence/relative_error_%d" % step,
                      tf.reduce_mean(var_rel_gap))
def summarize_model(model, true_bs, observation, states, bound, summarize_r=True):
  """Compares the model's `bs` with the data-generating `true_bs`.

  Only acts when `model.p` has a `bs` attribute. Both `model.p.bs` and
  `true_bs` are summed over axis 0. The function then writes scalar summaries
  of the batch means of the two sums and of their absolute and relative
  errors.

  Args:
    model: Model whose `p` may expose `bs`.
    true_bs: The data-generating process's counterpart of `model.p.bs`.
    observation: Currently unused (fed the disabled q summaries).
    states: Currently unused (fed the disabled q/r summaries).
    bound: Currently unused (gated the disabled r summaries).
    summarize_r: Currently unused (gated the disabled r summaries).
  """
  # The q/r convergence summaries are currently disabled:
  #   summarize_qs(model, observation, states)
  #   if bound == "fivo-aux" and summarize_r:
  #     summarize_rs(model, states)
  if not hasattr(model.p, "bs"):
    return
  learned_total = tf.reduce_sum(model.p.bs, axis=0)
  target_total = tf.reduce_sum(true_bs, axis=0)
  abs_gap = tf.abs(learned_total - target_total)
  rel_gap = abs_gap / target_total
  tf.summary.scalar("sum_of_bs/data_generating_process",
                    tf.reduce_mean(target_total))
  tf.summary.scalar("sum_of_bs/model", tf.reduce_mean(learned_total))
  tf.summary.scalar("sum_of_bs/absolute_error", tf.reduce_mean(abs_gap))
  tf.summary.scalar("sum_of_bs/relative_error", tf.reduce_mean(rel_gap))
def summarize_grads(grads, loss_name):
  """Tracks moving averages of gradient moments and summarizes them.

  All non-None gradients are flattened and concatenated into one vector.
  Exponential moving averages (decay 0.99) are kept for the vector and for
  its elementwise square. Their difference, E[g^2] - E[g]^2, is reported as
  a per-parameter gradient-variance estimate.

  Args:
    grads: Iterable of (gradient, variable) pairs (presumably from
      `Optimizer.compute_gradients`); pairs whose gradient is None are skipped.
    loss_name: Suffix used in the summary names.

  Returns:
    The op that updates the moving averages. The averages only change when
    this op is run.
  """
  ema = tf.train.ExponentialMovingAverage(decay=0.99)
  flat_pieces = [tf.reshape(grad, [-1]) for grad, _ in grads if grad is not None]
  grad_vector = tf.concat(flat_pieces, axis=0)
  grad_sq = tf.square(grad_vector)
  update_op = ema.apply([grad_vector, grad_sq])
  mean_grad = ema.average(grad_vector)
  mean_sq_grad = ema.average(grad_sq)
  grad_var = mean_sq_grad - tf.square(mean_grad)
  tf.summary.scalar("grad_variance/%s" % loss_name, tf.reduce_mean(grad_var))
  tf.summary.histogram("grad_variance/%s" % loss_name, grad_var)
  tf.summary.histogram("grad_mean/%s" % loss_name, mean_grad)
  return update_op