import matplotlib
# Select the non-interactive Agg backend before the pyplot machinery is
# imported, so figures can be rendered without a display.
matplotlib.use("Agg")
import matplotlib.pylab as plt
import numpy as np


def save_figure_to_numpy(fig):
    """Return the rendered contents of *fig* as an RGB image array.

    The figure's canvas must already have been drawn (e.g. via
    ``fig.canvas.draw()``) so that a rendered buffer exists.

    Args:
        fig: A Matplotlib figure attached to an Agg canvas.

    Returns:
        A writable ``numpy.ndarray`` of dtype ``uint8`` with shape
        ``(height, width, 3)``, where ``(width, height)`` is what
        ``fig.canvas.get_width_height()`` reports.
    """
    canvas = fig.canvas
    width, height = canvas.get_width_height()
    # ``np.fromstring`` in binary mode and ``canvas.tostring_rgb()`` are both
    # deprecated; read the RGBA buffer directly and drop the alpha channel.
    rgba = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8)
    # ``frombuffer`` yields a read-only view of the canvas memory; copy so the
    # caller gets an independent, writable, contiguous array that stays valid
    # after the figure is closed.
    return rgba.reshape(height, width, 4)[:, :, :3].copy()


def _render_to_numpy(fig):
    """Tighten the layout of *fig*, draw it, and return it as an RGB array."""
    fig.tight_layout()
    fig.canvas.draw()
    return save_figure_to_numpy(fig)


def plot_alignment_to_numpy(alignment, info=None):
    """Render an alignment matrix as a heat map and return it as an image.

    Args:
        alignment: 2-D array-like accepted by ``Axes.imshow``. Judging by the
            axis labels, rows are encoder timesteps and columns are decoder
            timesteps -- NOTE(review): confirm against the caller.
        info: Optional string appended below the x-axis label.

    Returns:
        ``uint8`` array of shape ``(height, width, 3)`` holding the rendered
        6x4-inch figure.
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        im = ax.imshow(alignment, aspect='auto', origin='lower',
                       interpolation='none')
        fig.colorbar(im, ax=ax)
        xlabel = 'Decoder timestep'
        if info is not None:
            xlabel += '\n\n' + info
        ax.set_xlabel(xlabel)
        ax.set_ylabel('Encoder timestep')
        return _render_to_numpy(fig)
    finally:
        # Close this specific figure (not merely the "current" one), even if
        # plotting or rendering raised, so figures do not accumulate.
        plt.close(fig)


def plot_spectrogram_to_numpy(spectrogram):
    """Render a spectrogram as a heat map and return it as an image.

    Args:
        spectrogram: 2-D array-like accepted by ``Axes.imshow``. Judging by
            the axis labels, rows are channels and columns are frames --
            NOTE(review): confirm against the caller.

    Returns:
        ``uint8`` array of shape ``(height, width, 3)`` holding the rendered
        12x3-inch figure.
    """
    fig, ax = plt.subplots(figsize=(12, 3))
    try:
        im = ax.imshow(spectrogram, aspect="auto", origin="lower",
                       interpolation='none')
        fig.colorbar(im, ax=ax)
        ax.set_xlabel("Frames")
        ax.set_ylabel("Channels")
        return _render_to_numpy(fig)
    finally:
        # Always release the figure created above, including on error.
        plt.close(fig)


def plot_gate_outputs_to_numpy(gate_targets, gate_outputs):
    """Scatter-plot target and predicted gate values and return an image.

    Targets are drawn as green ``+`` markers and predictions as red ``.``
    markers, each plotted against its index in the sequence.

    Args:
        gate_targets: Sized sequence of target gate values.
        gate_outputs: Sized sequence of predicted gate values.

    Returns:
        ``uint8`` array of shape ``(height, width, 3)`` holding the rendered
        12x3-inch figure.
    """
    fig, ax = plt.subplots(figsize=(12, 3))
    try:
        ax.scatter(range(len(gate_targets)), gate_targets, alpha=0.5,
                   color='green', marker='+', s=1, label='target')
        ax.scatter(range(len(gate_outputs)), gate_outputs, alpha=0.5,
                   color='red', marker='.', s=1, label='predicted')
        ax.set_xlabel("Frames (Green target, Red predicted)")
        ax.set_ylabel("Gate State")
        return _render_to_numpy(fig)
    finally:
        # Always release the figure created above, including on error.
        plt.close(fig)