from os import path

import librosa as rosa
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities import rank_zero_only

from utils.stft import STFTMag

# Select the non-interactive Agg backend so figures are rendered off-screen.
matplotlib.use('Agg')


class TensorBoardLoggerExpanded(TensorBoardLogger):
    """TensorBoard logger with helpers for logging spectrogram images and audio.

    The logger writes to ``lightning_logs`` with an empty experiment name and
    with the default ``hp_metric`` disabled.
    """

    def __init__(self, sr=16000):
        """Create the logger.

        Args:
            sr: Sample rate forwarded to ``add_audio`` in :meth:`log_audio`.
        """
        super().__init__(save_dir='lightning_logs',
                         default_hp_metric=False,
                         name='')
        self.sr = sr
        # Used to turn 1-D inputs into something plottable in
        # plot_spectrogram_to_numpy (presumably an STFT magnitude; see
        # utils.stft.STFTMag).
        self.stftmag = STFTMag()

    def fig2np(self, fig):
        """Return the rendered canvas of *fig* as an ``(H, W, 3)`` uint8 array.

        The canvas must already have been drawn (``fig.canvas.draw()``).

        Args:
            fig: A Matplotlib figure rendered with an Agg canvas.

        Returns:
            A writable ``numpy.ndarray`` of dtype ``uint8`` holding RGB pixels.
        """
        # ``canvas.tostring_rgb()`` and the binary mode of ``np.fromstring``
        # are both deprecated; ``buffer_rgba()`` exposes the same pixels as an
        # (H, W, 4) buffer, so the array shape comes straight from the canvas.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        # Drop alpha and copy so the result owns writable memory, independent
        # of the canvas buffer (the figure is closed by the caller).
        return rgba[..., :3].copy()

    def plot_spectrogram_to_numpy(self, y, y_low, y_recon, step):
        """Plot three signals as stacked dB spectrograms and return the image.

        Each input whose ``dim()`` is 1 is first passed through
        ``self.stftmag``; any other input is plotted as given (assumed to
        already be a 2-D magnitude spectrogram -- TODO confirm with callers).
        Inputs must support ``.numpy()``, i.e. be CPU tensors.

        Args:
            y: Reference signal (plotted under the title ``'y'``).
            y_low: Second signal (plotted under the title ``'y_low'``).
            y_recon: Third signal (plotted under the title ``'y_recon'``).
            step: Value shown in the figure title as ``Epoch_{step}``.

        Returns:
            The rendered figure as an ``(H, W, 3)`` uint8 ``numpy.ndarray``.
        """
        name_list = ['y', 'y_low', 'y_recon']
        fig = plt.figure(figsize=(9, 15))
        try:
            fig.suptitle(f'Epoch_{step}')
            for i, (title, yy) in enumerate(zip(name_list,
                                                [y, y_low, y_recon])):
                if yy.dim() == 1:
                    yy = self.stftmag(yy)
                ax = fig.add_subplot(3, 1, i + 1)
                ax.set_title(title)
                # Amplitudes are converted to dB relative to the maximum, so
                # the colour scale tops out at 0 dB and spans at most 80 dB.
                image = ax.imshow(
                    rosa.amplitude_to_db(yy.numpy(), ref=np.max, top_db=80.),
                    # vmin = -20,
                    vmax=0.,
                    aspect='auto',
                    origin='lower',
                    interpolation='none')
                fig.colorbar(image, ax=ax)
                ax.set_xlabel('Frames')
                ax.set_ylabel('Channels')
            fig.tight_layout()

            fig.canvas.draw()
            return self.fig2np(fig)
        finally:
            # Close this specific figure (not merely the "current" one), even
            # if plotting failed, so figures do not accumulate across epochs.
            plt.close(fig)

    @rank_zero_only
    def log_spectrogram(self, y, y_low, y_recon, epoch):
        """Render the three signals as spectrograms and log them as one image.

        Runs only on the rank-zero process. Inputs are detached and moved to
        the CPU before plotting.

        Args:
            y: Reference signal tensor.
            y_low: Second signal tensor.
            y_recon: Third signal tensor.
            epoch: Global step for the image; also shown in the figure title.
        """
        y, y_low, y_recon = (t.detach().cpu() for t in (y, y_low, y_recon))
        spec_img = self.plot_spectrogram_to_numpy(y, y_low, y_recon, epoch)
        self.experiment.add_image(path.join(self.save_dir, 'result'),
                                  spec_img,
                                  epoch,
                                  dataformats='HWC')
        self.experiment.flush()

    @rank_zero_only
    def log_audio(self, y, y_low, y_recon, epoch):
        """Log the three signals as audio under the tags y, y_low and y_recon.

        Runs only on the rank-zero process. Inputs are detached and moved to
        the CPU, then written with sample rate ``self.sr``.

        Args:
            y: Reference signal tensor.
            y_low: Second signal tensor.
            y_recon: Third signal tensor.
            epoch: Global step under which the audio clips are recorded.
        """
        y, y_low, y_recon = (t.detach().cpu() for t in (y, y_low, y_recon))
        name_list = ['y', 'y_low', 'y_recon']
        for n, yy in zip(name_list, [y, y_low, y_recon]):
            self.experiment.add_audio(n, yy, epoch, self.sr)
        self.experiment.flush()