File size: 2,520 Bytes
e34c0af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from os import path

import librosa as rosa
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities import rank_zero_only

from utils.stft import STFTMag

# Select Matplotlib's 'Agg' backend (a non-interactive raster backend), so the
# figures created in this module are rendered in memory and no GUI/display is needed.
matplotlib.use('Agg')


class TensorBoardLoggerExpanded(TensorBoardLogger):
    """TensorBoardLogger with helpers for logging spectrogram images and audio.

    The logger always writes under ``lightning_logs`` with an empty experiment
    name and with the default ``hp_metric`` disabled.
    """

    def __init__(self, sr=16000):
        """Create the logger.

        Args:
            sr: Sample rate passed to TensorBoard when audio is logged.
        """
        super().__init__(save_dir='lightning_logs', default_hp_metric=False, name='')
        self.sr = sr
        # Applied to 1-D inputs before plotting; presumably yields a magnitude
        # spectrogram -- see utils.stft.STFTMag to confirm.
        self.stftmag = STFTMag()

    def fig2np(self, fig):
        """Return the rendered pixels of ``fig`` as an ``(H, W, 3)`` uint8 RGB array.

        The canvas is expected to have been drawn already (``fig.canvas.draw()``).

        Args:
            fig: A Matplotlib figure whose canvas exposes ``buffer_rgba()``.

        Returns:
            A freshly allocated, C-contiguous ``numpy.ndarray`` of dtype uint8.
        """
        # buffer_rgba() replaces the deprecated tostring_rgb() + np.fromstring()
        # combination; its array view already has shape (H, W, 4).
        rgba = np.asarray(fig.canvas.buffer_rgba())
        # Drop alpha and copy so the result does not alias the canvas buffer.
        return np.ascontiguousarray(rgba[..., :3])

    def plot_spectrogram_to_numpy(self, y, y_low, y_recon, step):
        """Plot three signals as stacked dB spectrograms and return the image.

        Each input whose ``dim()`` is 1 is first passed through ``self.stftmag``;
        other inputs are plotted as given.

        Args:
            y, y_low, y_recon: CPU tensors (``.numpy()`` is called on them).
            step: Value shown in the figure title (``Epoch_<step>``).

        Returns:
            ``(H, W, 3)`` uint8 RGB array of the rendered figure.
        """
        fig = plt.figure(figsize=(9, 15))
        try:
            fig.suptitle(f'Epoch_{step}')
            panels = zip(('y', 'y_low', 'y_recon'), (y, y_low, y_recon))
            for row, (title, yy) in enumerate(panels, start=1):
                if yy.dim() == 1:
                    yy = self.stftmag(yy)
                ax = fig.add_subplot(3, 1, row)
                ax.set_title(title)
                # dB relative to the panel's own maximum, floored 80 dB below it.
                spec_db = rosa.amplitude_to_db(yy.numpy(), ref=np.max, top_db=80.)
                image = ax.imshow(spec_db,
                                  vmax=0.,
                                  aspect='auto',
                                  origin='lower',
                                  interpolation='none')
                fig.colorbar(image, ax=ax)
                ax.set_xlabel('Frames')
                ax.set_ylabel('Channels')

            # Lay out once, after every axes and colorbar exists.
            fig.tight_layout()
            fig.canvas.draw()
            return self.fig2np(fig)
        finally:
            # Close this specific figure, even on error, so figures do not
            # accumulate in pyplot's registry over the course of training.
            plt.close(fig)

    @rank_zero_only
    def log_spectrogram(self, y, y_low, y_recon, epoch):
        """Render the three signals as spectrograms and log them as one image.

        Args:
            y, y_low, y_recon: Tensors; they are detached and moved to CPU here.
            epoch: Used both as the figure title suffix and as the global step.
        """
        y, y_low, y_recon = (t.detach().cpu() for t in (y, y_low, y_recon))
        spec_img = self.plot_spectrogram_to_numpy(y, y_low, y_recon, epoch)
        self.experiment.add_image(path.join(self.save_dir, 'result'),
                                  spec_img,
                                  epoch,
                                  dataformats='HWC')
        self.experiment.flush()

    @rank_zero_only
    def log_audio(self, y, y_low, y_recon, epoch):
        """Log the three signals as audio clips tagged 'y', 'y_low' and 'y_recon'.

        Args:
            y, y_low, y_recon: Tensors; they are detached and moved to CPU here.
            epoch: Global step under which the clips are recorded.
        """
        clips = {'y': y, 'y_low': y_low, 'y_recon': y_recon}
        for tag, signal in clips.items():
            self.experiment.add_audio(tag, signal.detach().cpu(), epoch, self.sr)
        self.experiment.flush()