File size: 7,944 Bytes
b725c5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import librosa
import torch

import numpy as np


def extract_mstft(
    audio_ref,
    audio_deg,
    fs=None,
    mid_freq=None,
    high_freq=None,
    method="cut",
    version="pwg",
):
    """Compute Multi-Scale STFT Distance (mstft) between the predicted and the ground truth audio.
    audio_ref: path to the ground truth audio.
    audio_deg: path to the predicted audio.
    fs: sampling rate.
    med_freq: division frequency for mid frequency parts.
    high_freq: division frequency for high frequency parts.
    method: "dtw" will use dtw algorithm to align the length of the ground truth and predicted audio.
            "cut" will cut both audios into a same length according to the one with the shorter length.
    version: "pwg" will use the computational version provided by ParallelWaveGAN.
             "encodec" will use the computational version provided by Encodec.
    """
    # Load audio
    if fs != None:
        audio_ref, _ = librosa.load(audio_ref, sr=fs)
        audio_deg, _ = librosa.load(audio_deg, sr=fs)
    else:
        audio_ref, fs = librosa.load(audio_ref)
        audio_deg, fs = librosa.load(audio_deg)

    # Automatically choose mid_freq and high_freq if they are not given
    if mid_freq == None:
        mid_freq = fs // 6
    if high_freq == None:
        high_freq = fs // 3

    # Audio length alignment
    if len(audio_ref) != len(audio_deg):
        if method == "cut":
            length = min(len(audio_ref), len(audio_deg))
            audio_ref = audio_ref[:length]
            audio_deg = audio_deg[:length]
        elif method == "dtw":
            _, wp = librosa.sequence.dtw(audio_ref, audio_deg, backtrack=True)
            audio_ref_new = []
            audio_deg_new = []
            for i in range(wp.shape[0]):
                ref_index = wp[i][0]
                deg_index = wp[i][1]
                audio_ref_new.append(audio_ref[ref_index])
                audio_deg_new.append(audio_deg[deg_index])
            audio_ref = np.array(audio_ref_new)
            audio_deg = np.array(audio_deg_new)
            assert len(audio_ref) == len(audio_deg)

    # Define loss function
    l1Loss = torch.nn.L1Loss(reduction="mean")
    l2Loss = torch.nn.MSELoss(reduction="mean")

    # Compute distance
    if version == "encodec":
        n_fft = 1024

        mstft = 0
        mstft_low = 0
        mstft_mid = 0
        mstft_high = 0

        freq_resolution = fs / n_fft
        mid_freq_index = 1 + int(np.floor(mid_freq / freq_resolution))
        high_freq_index = 1 + int(np.floor(high_freq / freq_resolution))

        for i in range(5, 11):
            hop_length = 2**i // 4
            win_length = 2**i

            spec_ref = librosa.stft(
                y=audio_ref, n_fft=n_fft, hop_length=hop_length, win_length=win_length
            )
            spec_deg = librosa.stft(
                y=audio_deg, n_fft=n_fft, hop_length=hop_length, win_length=win_length
            )

            mag_ref = np.abs(spec_ref)
            mag_deg = np.abs(spec_deg)

            mag_ref = torch.from_numpy(mag_ref)
            mag_deg = torch.from_numpy(mag_deg)
            mstft += l1Loss(mag_ref, mag_deg) + l2Loss(mag_ref, mag_deg)

            mag_ref_low = mag_ref[:mid_freq_index, :]
            mag_deg_low = mag_deg[:mid_freq_index, :]
            mstft_low += l1Loss(mag_ref_low, mag_deg_low) + l2Loss(
                mag_ref_low, mag_deg_low
            )

            mag_ref_mid = mag_ref[mid_freq_index:high_freq_index, :]
            mag_deg_mid = mag_deg[mid_freq_index:high_freq_index, :]
            mstft_mid += l1Loss(mag_ref_mid, mag_deg_mid) + l2Loss(
                mag_ref_mid, mag_deg_mid
            )

            mag_ref_high = mag_ref[high_freq_index:, :]
            mag_deg_high = mag_deg[high_freq_index:, :]
            mstft_high += l1Loss(mag_ref_high, mag_deg_high) + l2Loss(
                mag_ref_high, mag_deg_high
            )

        mstft /= 6
        mstft_low /= 6
        mstft_mid /= 6
        mstft_high /= 6

        return mstft
    elif version == "pwg":
        fft_sizes = [1024, 2048, 512]
        hop_sizes = [120, 240, 50]
        win_sizes = [600, 1200, 240]

        audio_ref = torch.from_numpy(audio_ref)
        audio_deg = torch.from_numpy(audio_deg)

        mstft_sc = 0
        mstft_sc_low = 0
        mstft_sc_mid = 0
        mstft_sc_high = 0

        mstft_mag = 0
        mstft_mag_low = 0
        mstft_mag_mid = 0
        mstft_mag_high = 0

        for n_fft, hop_length, win_length in zip(fft_sizes, hop_sizes, win_sizes):
            spec_ref = torch.stft(
                audio_ref, n_fft, hop_length, win_length, return_complex=False
            )
            spec_deg = torch.stft(
                audio_deg, n_fft, hop_length, win_length, return_complex=False
            )

            real_ref = spec_ref[..., 0]
            imag_ref = spec_ref[..., 1]
            real_deg = spec_deg[..., 0]
            imag_deg = spec_deg[..., 1]

            mag_ref = torch.sqrt(
                torch.clamp(real_ref**2 + imag_ref**2, min=1e-7)
            ).transpose(1, 0)
            mag_deg = torch.sqrt(
                torch.clamp(real_deg**2 + imag_deg**2, min=1e-7)
            ).transpose(1, 0)
            sc_loss = torch.norm(mag_ref - mag_deg, p="fro") / torch.norm(
                mag_ref, p="fro"
            )
            mag_loss = l1Loss(torch.log(mag_ref), torch.log(mag_deg))

            mstft_sc += sc_loss
            mstft_mag += mag_loss

            freq_resolution = fs / n_fft
            mid_freq_index = 1 + int(np.floor(mid_freq / freq_resolution))
            high_freq_index = 1 + int(np.floor(high_freq / freq_resolution))

            mag_ref_low = mag_ref[:, :mid_freq_index]
            mag_deg_low = mag_deg[:, :mid_freq_index]
            sc_loss_low = torch.norm(mag_ref_low - mag_deg_low, p="fro") / torch.norm(
                mag_ref_low, p="fro"
            )
            mag_loss_low = l1Loss(torch.log(mag_ref_low), torch.log(mag_deg_low))

            mstft_sc_low += sc_loss_low
            mstft_mag_low += mag_loss_low

            mag_ref_mid = mag_ref[:, mid_freq_index:high_freq_index]
            mag_deg_mid = mag_deg[:, mid_freq_index:high_freq_index]
            sc_loss_mid = torch.norm(mag_ref_mid - mag_deg_mid, p="fro") / torch.norm(
                mag_ref_mid, p="fro"
            )
            mag_loss_mid = l1Loss(torch.log(mag_ref_mid), torch.log(mag_deg_mid))

            mstft_sc_mid += sc_loss_mid
            mstft_mag_mid += mag_loss_mid

            mag_ref_high = mag_ref[:, high_freq_index:]
            mag_deg_high = mag_deg[:, high_freq_index:]
            sc_loss_high = torch.norm(
                mag_ref_high - mag_deg_high, p="fro"
            ) / torch.norm(mag_ref_high, p="fro")
            mag_loss_high = l1Loss(torch.log(mag_ref_high), torch.log(mag_deg_high))

            mstft_sc_high += sc_loss_high
            mstft_mag_high += mag_loss_high

        # Normalize distances
        mstft_sc /= len(fft_sizes)
        mstft_sc_low /= len(fft_sizes)
        mstft_sc_mid /= len(fft_sizes)
        mstft_sc_high /= len(fft_sizes)

        mstft_mag /= len(fft_sizes)
        mstft_mag_low /= len(fft_sizes)
        mstft_mag_mid /= len(fft_sizes)
        mstft_mag_high /= len(fft_sizes)

        # return (
        #     mstft_sc.numpy().tolist(),
        #     mstft_sc_low.numpy().tolist(),
        #     mstft_sc_mid.numpy().tolist(),
        #     mstft_sc_high.numpy().tolist(),
        #     mstft_mag.numpy().tolist(),
        #     mstft_mag_low.numpy().tolist(),
        #     mstft_mag_mid.numpy().tolist(),
        #     mstft_mag_high.numpy().tolist(),
        # )

        return mstft_sc.numpy().tolist() + mstft_mag.numpy().tolist()