|
import matplotlib |
|
matplotlib.use("Agg") |
|
import matplotlib.pylab as plt |
|
import numpy as np |
|
|
|
|
|
def save_figure_to_numpy(fig):
    """Return the pixels of a figure's canvas as an RGB ``uint8`` array.

    The canvas is read as-is; the callers in this file invoke
    ``fig.canvas.draw()`` before calling this helper.

    Args:
        fig: Object exposing a ``canvas`` with ``get_width_height()`` and
            either ``buffer_rgba()`` or the legacy ``tostring_rgb()``.

    Returns:
        A writable ``numpy.ndarray`` of dtype ``uint8`` with shape
        ``(height, width, 3)``.
    """
    canvas = fig.canvas

    # Preferred path: ``tostring_rgb`` was deprecated and later removed from
    # Matplotlib, while ``buffer_rgba`` is the supported accessor.
    if hasattr(canvas, "buffer_rgba"):
        rgba = np.asarray(canvas.buffer_rgba(), dtype=np.uint8)
        # Drop the alpha channel; copy so the result owns writable memory
        # instead of aliasing the renderer's buffer.
        return rgba[..., :3].copy()

    # Legacy path. ``np.frombuffer`` replaces the binary mode of
    # ``np.fromstring``, which NumPy deprecated. ``frombuffer`` yields a
    # read-only view over the bytes, so copy to keep the result writable as
    # ``fromstring`` did.
    width, height = canvas.get_width_height()
    pixels = np.frombuffer(canvas.tostring_rgb(), dtype=np.uint8)
    return pixels.reshape((height, width, 3)).copy()
|
|
|
|
|
def plot_alignment_to_numpy(alignment, info=None):
    """Render an alignment matrix as a heatmap and return it as an RGB array.

    The matrix is drawn with ``imshow`` (origin at the lower-left) alongside
    a colorbar. The x axis is labelled "Decoder timestep" and the y axis
    "Encoder timestep".

    Args:
        alignment: 2-D array-like accepted by ``Axes.imshow``.
            NOTE(review): per the axis labels, rows are presumably encoder
            steps and columns decoder steps -- confirm against callers.
        info: Optional string appended below the x-axis label, separated by
            a blank line.

    Returns:
        The rendered figure as returned by ``save_figure_to_numpy``.
    """
    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        im = ax.imshow(alignment, aspect='auto', origin='lower',
                       interpolation='none')
        fig.colorbar(im, ax=ax)

        xlabel = 'Decoder timestep'
        if info is not None:
            xlabel += '\n\n' + info

        # Set labels on the axes object directly rather than relying on
        # pyplot's implicit "current axes" state.
        ax.set_xlabel(xlabel)
        ax.set_ylabel('Encoder timestep')
        fig.tight_layout()

        fig.canvas.draw()
        return save_figure_to_numpy(fig)
    finally:
        # Close this specific figure, even if rendering raised, so figures
        # do not accumulate in pyplot's registry.
        plt.close(fig)
|
|
|
|
|
def plot_spectrogram_to_numpy(spectrogram):
    """Render a spectrogram as a heatmap and return it as an RGB array.

    The input is drawn with ``imshow`` (origin at the lower-left) alongside
    a colorbar. The x axis is labelled "Frames" and the y axis "Channels".

    Args:
        spectrogram: 2-D array-like accepted by ``Axes.imshow``.
            NOTE(review): per the axis labels, presumably laid out as
            (channels, frames) -- confirm against callers.

    Returns:
        The rendered figure as returned by ``save_figure_to_numpy``.
    """
    fig, ax = plt.subplots(figsize=(12, 3))
    try:
        im = ax.imshow(spectrogram, aspect="auto", origin="lower",
                       interpolation='none')
        # Use the figure's own colorbar method, matching
        # plot_alignment_to_numpy and avoiding pyplot's implicit state.
        fig.colorbar(im, ax=ax)
        ax.set_xlabel("Frames")
        ax.set_ylabel("Channels")
        fig.tight_layout()

        fig.canvas.draw()
        return save_figure_to_numpy(fig)
    finally:
        # Close this specific figure, even if rendering raised, so figures
        # do not accumulate in pyplot's registry.
        plt.close(fig)
|
|
|
|
|
def plot_gate_outputs_to_numpy(gate_targets, gate_outputs):
    """Scatter-plot target and predicted gate values and return an RGB array.

    Each sequence is plotted against its own index: targets as green ``+``
    markers and predictions as red ``.`` markers.

    Args:
        gate_targets: Sized sequence of target gate values.
        gate_outputs: Sized sequence of predicted gate values.

    Returns:
        The rendered figure as returned by ``save_figure_to_numpy``.
    """
    fig, ax = plt.subplots(figsize=(12, 3))
    try:
        ax.scatter(range(len(gate_targets)), gate_targets, alpha=0.5,
                   color='green', marker='+', s=1, label='target')
        ax.scatter(range(len(gate_outputs)), gate_outputs, alpha=0.5,
                   color='red', marker='.', s=1, label='predicted')

        # No legend is drawn; the x-axis label carries the colour key.
        ax.set_xlabel("Frames (Green target, Red predicted)")
        ax.set_ylabel("Gate State")
        fig.tight_layout()

        fig.canvas.draw()
        return save_figure_to_numpy(fig)
    finally:
        # Close this specific figure, even if rendering raised, so figures
        # do not accumulate in pyplot's registry.
        plt.close(fig)
|
|