File size: 6,824 Bytes
d5ee97c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
# -*- coding: utf-8 -*-
# Copyright 2020 Minh Nguyen (@dathudeptrai)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Griffin-Lim phase reconstruction algorithm from mel spectrogram."""

import os

import librosa
import numpy as np
import soundfile as sf
import tensorflow as tf
from sklearn.preprocessing import StandardScaler


def griffin_lim_lb(
    mel_spec, stats_path, dataset_config, n_iter=32, output_dir=None, wav_name="lb"
):
    """Generate wave from mel spectrogram with Griffin-Lim algorithm using Librosa.

    The mel spectrogram is de-standardized with the saved mean/scale statistics,
    mapped back from log10 to linear magnitude, projected onto the linear
    frequency axis with the pseudo-inverse of the mel filterbank, and finally
    inverted with `librosa.griffinlim`.

    Args:
        mel_spec (ndarray): normalized log10 mel spectrogram, expected shape
            [frames, num_mels] (it is transposed before the filterbank product).
        stats_path (str): path to the `stats.npy` file containing norm statistics
            (mean and scale, in that order).
        dataset_config (Dict): dataset configuration parameters. Keys read:
            `sampling_rate`, `fft_size`, `num_mels`, `fmin`, `fmax`, `hop_size`
            and `win_length` (falls back to `fft_size` when falsy).
        n_iter (int): number of iterations for GL.
        output_dir (str): output directory where audio file will be saved; when
            falsy, nothing is written.
        wav_name (str): name of the output file (without the `.wav` extension).
    Returns:
        gl_lb (ndarray): generated wave.
    """
    scaler = StandardScaler()
    scaler.mean_, scaler.scale_ = np.load(stats_path)

    # Undo standardization, then undo the log10 compression; transpose to
    # [num_mels, frames] so it can be left-multiplied by the inverse filterbank.
    mel_spec = np.power(10.0, scaler.inverse_transform(mel_spec)).T
    # `sr` is keyword-only in librosa >= 0.10; passing it by keyword also works
    # on older librosa releases.
    mel_basis = librosa.filters.mel(
        sr=dataset_config["sampling_rate"],
        n_fft=dataset_config["fft_size"],
        n_mels=dataset_config["num_mels"],
        fmin=dataset_config["fmin"],
        fmax=dataset_config["fmax"],
    )
    # The pseudo-inverse projection can produce non-positive values, so floor
    # the magnitudes at a small positive constant.
    mel_to_linear = np.maximum(1e-10, np.dot(np.linalg.pinv(mel_basis), mel_spec))
    gl_lb = librosa.griffinlim(
        mel_to_linear,
        n_iter=n_iter,
        hop_length=dataset_config["hop_size"],
        win_length=dataset_config["win_length"] or dataset_config["fft_size"],
    )
    if output_dir:
        output_path = os.path.join(output_dir, f"{wav_name}.wav")
        sf.write(output_path, gl_lb, dataset_config["sampling_rate"], "PCM_16")
    return gl_lb


class TFGriffinLim(tf.keras.layers.Layer):
    """Griffin-Lim algorithm for phase reconstruction from mel spectrogram magnitude.

    Mel magnitudes are (optionally) de-standardized, mapped from log10 back to
    linear scale, projected onto the linear-frequency axis with the
    pseudo-inverse of a librosa mel filterbank, and the phase is then estimated
    iteratively with `tf.signal.stft` / `tf.signal.inverse_stft`.
    """

    def __init__(self, stats_path, dataset_config, normalized: bool = True):
        """Init GL params.
        Args:
            stats_path (str): path to the `stats.npy` file containing norm
                statistics (mean and scale). Only read when `normalized` is True.
            dataset_config (Dict): dataset configuration parameters. Keys read:
                `sampling_rate`, `fft_size`, `num_mels`, `fmin`, `fmax`,
                `hop_size` and `win_length` (falls back to `fft_size` when falsy).
            normalized (bool): whether input mel spectrograms are standardized
                with the statistics in `stats_path` and must be de-standardized
                before inversion.
        """
        super().__init__()
        self.normalized = normalized
        if normalized:
            scaler = StandardScaler()
            scaler.mean_, scaler.scale_ = np.load(stats_path)
            self.scaler = scaler
        self.ds_config = dataset_config
        # `sr` is keyword-only in librosa >= 0.10; passing it by keyword also
        # works on older librosa releases.
        self.mel_basis = librosa.filters.mel(
            sr=self.ds_config["sampling_rate"],
            n_fft=self.ds_config["fft_size"],
            n_mels=self.ds_config["num_mels"],
            fmin=self.ds_config["fmin"],
            fmax=self.ds_config["fmax"],
        )  # [num_mels, fft_size // 2 + 1]

    def _stft_frame_length(self):
        """Return the STFT window length: `win_length`, or `fft_size` when falsy."""
        return self.ds_config["win_length"] or self.ds_config["fft_size"]

    def _griffin_lim_istft(self, stft_matrix):
        """Invert a complex STFT using this layer's framing parameters."""
        return tf.signal.inverse_stft(
            stft_matrix,
            frame_length=self._stft_frame_length(),
            frame_step=self.ds_config["hop_size"],
            fft_length=self.ds_config["fft_size"],
            window_fn=tf.signal.inverse_stft_window_fn(self.ds_config["hop_size"]),
        )

    def save_wav(self, gl_tf, output_dir, wav_name):
        """Generate WAV file(s) and save them.
        Args:
            gl_tf (tf.Tensor): reconstructed signal from GL algorithm, either a
                single waveform [samples] or a batch [batch, samples].
            output_dir (str): output directory where audio files will be saved.
            wav_name (str or list[str]): name(s) of the output file(s), without
                the `.wav` extension; for a batch, one name per batch item.
        """
        encode_fn = lambda x: tf.audio.encode_wav(x, self.ds_config["sampling_rate"])
        # encode_wav expects [length, channels]: add a mono channel axis.
        gl_tf = tf.expand_dims(gl_tf, -1)
        if not isinstance(wav_name, list):
            wav_name = [wav_name]

        if len(gl_tf.shape) > 2:
            bs, *_ = gl_tf.shape
            assert bs == len(wav_name), "Batch and 'wav_name' have different size."
            tf_wav = tf.map_fn(encode_fn, gl_tf, dtype=tf.string)
            # The assertion above guarantees one name per batch item.
            for idx, name in enumerate(wav_name):
                output_path = os.path.join(output_dir, f"{name}.wav")
                tf.io.write_file(output_path, tf_wav[idx])
        else:
            tf_wav = encode_fn(gl_tf)
            tf.io.write_file(os.path.join(output_dir, f"{wav_name[0]}.wav"), tf_wav)

    @tf.function(
        input_signature=[
            tf.TensorSpec(shape=[None, None, None], dtype=tf.float32),
            tf.TensorSpec(shape=[], dtype=tf.int32),
        ]
    )
    def call(self, mel_spec, n_iter=32):
        """Apply GL algorithm to batched mel spectrograms.
        Args:
            mel_spec (tf.Tensor): log10 mel spectrogram (standardized when the
                layer was built with `normalized=True`); the last axis is the
                mel axis.
            n_iter (int): number of iterations to run GL algorithm.
        Returns:
            (tf.Tensor): reconstructed signal from GL algorithm.
        """
        # de-normalize mel spectrogram and undo the log10 compression
        if self.normalized:
            mel_spec = tf.math.pow(
                10.0, mel_spec * self.scaler.scale_ + self.scaler.mean_
            )
        else:
            # TODO @dathudeptrai check if its ok without it wavs were too quiet
            mel_spec = tf.math.pow(10.0, mel_spec)
        inverse_mel = tf.linalg.pinv(self.mel_basis)

        # [:, num_mels] @ [fft_size // 2 + 1, num_mels].T
        mel_to_linear = tf.linalg.matmul(mel_spec, inverse_mel, transpose_b=True)
        # Floor at 1e-10 (the pseudo-inverse can produce non-positive values)
        # and promote to complex so it can be multiplied by the phase estimate.
        mel_to_linear = tf.cast(tf.math.maximum(1e-10, mel_to_linear), tf.complex64)

        # Random initial phase, uniformly distributed on the unit circle.
        init_phase = tf.cast(
            tf.random.uniform(tf.shape(mel_to_linear), maxval=1), tf.complex64
        )
        phase = tf.math.exp(2j * np.pi * init_phase)
        frame_length = self._stft_frame_length()
        for _ in tf.range(n_iter):
            inverse = self._griffin_lim_istft(mel_to_linear * phase)
            phase = tf.signal.stft(
                inverse,
                frame_length,
                self.ds_config["hop_size"],
                self.ds_config["fft_size"],
            )
            # Keep only the phase: scale every bin to unit magnitude.
            phase /= tf.cast(tf.maximum(1e-10, tf.abs(phase)), tf.complex64)

        return self._griffin_lim_istft(mel_to_linear * phase)