XDHDD committed on
Commit
e34c0af
1 Parent(s): 642d254

Upload 8 files

Browse files
PLCMOS/models/plcmos_v0.onnx ADDED
Binary file (691 kB). View file
 
PLCMOS/models/plcmos_v1_intrusive.onnx ADDED
Binary file (280 kB). View file
 
PLCMOS/models/plcmos_v1_nonintrusive.onnx ADDED
Binary file (129 kB). View file
 
PLCMOS/plc_mos.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+
4
+ import librosa
5
+ import numpy as np
6
+ import onnxruntime as ort
7
+ from numpy.fft import rfft
8
+ from numpy.lib.stride_tricks import as_strided
9
+
10
class PLCMOSEstimator():
    """Estimate PLC-MOS scores for audio using the bundled ONNX models."""

    def __init__(self, model_version=1):
        """
        Initialize a PLC-MOS model of a given version. There are currently three models available, v0 (intrusive)
        and v1 (both non-intrusive and intrusive available). The default is to use the v1 models.

        `self.sessions` and `self.max_lens` are filled as `[intrusive, nonintrusive]`; a model that does
        not exist for the chosen version is stored as `None` with a maximum length of 0.
        """

        self.model_version = model_version
        model_paths = [
            # v0 model:
            [("models/plcmos_v0.onnx", 999999999999), (None, 0)],

            # v1 models:
            [("models/plcmos_v1_intrusive.onnx", 768),
             ("models/plcmos_v1_nonintrusive.onnx", 999999999999)],
        ]
        self.sessions = []
        self.max_lens = []
        options = ort.SessionOptions()
        options.intra_op_num_threads = 8
        options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        # Model paths are relative to this source file, not the working directory.
        file_dir = os.path.dirname(os.path.realpath(__file__))
        for path, max_len in model_paths[model_version]:
            if path is not None:
                self.sessions.append(ort.InferenceSession(
                    os.path.join(file_dir, path), options))
                self.max_lens.append(max_len)
            else:
                self.sessions.append(None)
                self.max_lens.append(0)

    def logpow_dns(self, sig, floor=-30.):
        """
        Compute log power of complex spectrum.

        Floor any -`np.inf` value to (nonzero minimum + `floor`) dB.
        If all values are 0s, floor all values to -80 dB.

        The result is in natural-log units; dB offsets are converted with
        `dB / 10 / log10(e)`.
        """
        log10e = np.log10(np.e)
        pspec = sig.real ** 2 + sig.imag ** 2
        zeros = pspec == 0
        logp = np.empty_like(pspec)
        if np.any(~zeros):
            logp[~zeros] = np.log(pspec[~zeros])
            logp[zeros] = np.log(pspec[~zeros].min()) + floor / 10 / log10e
        else:
            logp.fill(-80 / 10 / log10e)

        return logp

    def hop2hsize(self, wind, hop):
        """
        Convert hop fraction to integer size if necessary.

        `hop >= 1` is taken as a hop size in samples and must be an integer
        (Python or NumPy); `0 < hop < 1` is taken as a fraction of `len(wind)`.
        """
        if hop >= 1:
            # Accept NumPy integer scalars as well as plain ints; bool is not a hop size.
            is_integer = isinstance(hop, (int, np.integer)) and not isinstance(hop, bool)
            assert is_integer, "Hop size must be integer!"
            return int(hop)
        assert 0 < hop < 1, "Hop fraction has to be in range (0,1)!"
        return int(len(wind) * hop)

    def stana(self, sig, sr, wind, hop, synth=False, center=False):
        """
        Short term analysis by windowing

        Returns an array of shape (nframe, len(wind)) holding the zero-padded,
        windowed frames of the 1-D signal `sig`. `sr` is not used here.
        """
        sig = np.asarray(sig)
        ssize = len(sig)
        fsize = len(wind)
        hsize = self.hop2hsize(wind, hop)
        if synth:
            sstart = hsize - fsize  # int(-fsize * (1-hfrac))
        elif center:
            sstart = -int(len(wind) / 2)  # odd window centered at exactly n=0
        else:
            sstart = 0
        send = ssize

        nframe = math.ceil((send - sstart) / hsize)

        # Calculate zero-padding sizes
        zpleft = -sstart
        zpright = (nframe - 1) * hsize + fsize - zpleft - ssize
        if zpleft > 0 or zpright > 0:
            sigpad = np.zeros(ssize + zpleft + zpright, dtype=sig.dtype)
            sigpad[zpleft:len(sigpad) - zpright] = sig
        else:
            # The strides below assume densely packed samples, so a sliced or
            # strided input must be made contiguous before it is viewed.
            sigpad = np.ascontiguousarray(sig)

        step = sigpad.itemsize
        frames = as_strided(sigpad, shape=(nframe, fsize),
                            strides=(step * hsize, step), writeable=False)
        # The multiplication allocates a fresh array, so no view of sigpad escapes.
        return frames * wind

    def stft(self, sig, sr, wind, hop, nfft):
        """
        Compute STFT: window + rfft
        """
        frames = self.stana(sig, sr, wind, hop, synth=True)
        return rfft(frames, n=nfft)

    def stft_transform(self, audio, dft_size=512, hop_fraction=0.5, sr=16000):
        """
        Compute STFT parameters, then compute STFT

        Returns the log-power features of the STFT magnitude divided by 20.
        """
        # Drop the last sample of a (dft_size + 1)-point Hamming window.
        window = np.hamming(dft_size + 1)
        window = window[:-1]
        amp = np.abs(self.stft(audio, sr, window, hop_fraction, dft_size))
        feat = self.logpow_dns(amp, floor=-120.)
        return feat / 20.

    def _features(self, audio):
        """Return the model input tensor (two leading singleton axes) for `audio`."""
        feats = np.float32(self.stft_transform(audio))[np.newaxis, np.newaxis, ...]
        # NOTE(review): len() here is the leading singleton axis (always 1), so
        # this check can never fail; it presumably was meant to compare the
        # number of frames. Left unchanged to avoid rejecting inputs that are
        # accepted today -- confirm the real model limit before tightening.
        assert len(feats) <= self.max_lens[0], "Maximum input length exceeded"
        return feats

    @staticmethod
    def _score(session, onnx_inputs):
        """Run `session` and return its first output (a single value) as a float."""
        # .item() avoids the deprecated float() conversion of a non-0-d array.
        return float(np.asarray(session.run(None, onnx_inputs)[0]).item())

    def run(self, audio_degraded, audio_clean=None, combined=False):
        """
        Run the PLCMOS models and return the MOS values for the given audio as a
        two-element list `[intrusive_mos, nonintrusive_mos]`.

        The intrusive model is run when a clean reference is passed and the selected
        model version has an intrusive model; the nonintrusive model is run when the
        selected model version has one. A score that could not be computed is
        reported as `None`. If neither model can be run, a `ValueError` is raised.

        `combined` is accepted for backward compatibility only; it does not affect
        the result.

        For intrusive models, the clean reference should be the unprocessed audio file the degraded audio is
        based on. It is not required to be aligned with the degraded audio.

        Audio data should be 16kHz, mono, [-1, 1] range.
        """
        audio_features_degraded = self._features(audio_degraded)

        intrusive_mos = None
        intrusive_session = self.sessions[0]
        if audio_clean is not None and intrusive_session is not None:
            audio_features_clean = self._features(audio_clean)
            intrusive_mos = self._score(
                intrusive_session,
                {"degraded_audio": audio_features_degraded,
                 "clean_audio": audio_features_clean})

        nonintrusive_mos = None
        nonintrusive_session = self.sessions[1]
        if nonintrusive_session is not None:
            nonintrusive_mos = self._score(
                nonintrusive_session,
                {"degraded_audio": audio_features_degraded})

        if intrusive_mos is None and nonintrusive_mos is None:
            raise ValueError(
                "No PLC-MOS model could be run: this model version has no "
                "nonintrusive model and no usable clean reference was given.")

        return [intrusive_mos, nonintrusive_mos]
utils/__init__.py ADDED
File without changes
utils/stft.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
class STFTMag(nn.Module):
    """Module that returns the STFT magnitude of a waveform, computed on the CPU."""

    def __init__(self,
                 nfft=1024,
                 hop=256):
        """
        Args:
            nfft: FFT size, also used as the length of the Hann window.
            hop: hop length passed to `torch.stft`.
        """
        super().__init__()
        self.nfft = nfft
        self.hop = hop
        # Third argument is `persistent=False`: the window is not saved in the state dict.
        self.register_buffer('window', torch.hann_window(nfft), False)

    # x: [B,T] or [T]
    @torch.no_grad()
    def forward(self, x):
        """Return the magnitude spectrogram of `x` as a real CPU tensor ([B, F, TT] or [F, TT])."""
        x = x.cpu()
        # The buffer follows the module's device, while the input is forced to
        # the CPU, so the window has to be moved to match.
        stft = torch.stft(x,
                          self.nfft,
                          self.hop,
                          window=self.window.to(x.device),
                          return_complex=True)  # [B, F, TT], complex
        # |re + j*im| is the same value the L2 norm over a trailing (re, im) axis gives.
        mag = stft.abs()  # [B, F, TT]
        return mag
utils/tblogger.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from os import path
2
+
3
+ import librosa as rosa
4
+ import matplotlib
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ from pytorch_lightning.loggers import TensorBoardLogger
8
+ from pytorch_lightning.utilities import rank_zero_only
9
+
10
+ from utils.stft import STFTMag
11
+
12
+ matplotlib.use('Agg')
13
+
14
+
15
class TensorBoardLoggerExpanded(TensorBoardLogger):
    """TensorBoard logger with helpers for logging spectrogram images and audio clips."""

    def __init__(self, sr=16000):
        """
        Args:
            sr: sample rate passed to `add_audio` when logging audio.
        """
        super().__init__(save_dir='lightning_logs', default_hp_metric=False, name='')
        self.sr = sr
        self.stftmag = STFTMag()

    def fig2np(self, fig):
        """Return the rendered figure as an (H, W, 3) uint8 RGB array.

        The figure's canvas must already have been drawn.
        """
        # buffer_rgba() replaces the removed tostring_rgb(); np.asarray avoids
        # the deprecated binary mode of np.fromstring.
        rgba = np.asarray(fig.canvas.buffer_rgba())
        # Drop the alpha channel and copy so the result does not alias the canvas buffer.
        return rgba[..., :3].copy()

    def plot_spectrogram_to_numpy(self, y, y_low, y_recon, step):
        """Plot the three signals as stacked dB spectrograms and return the image array.

        One-dimensional inputs are converted with `self.stftmag` first; inputs
        of any other rank are plotted as given.
        """
        name_list = ['y', 'y_low', 'y_recon']
        fig = plt.figure(figsize=(9, 15))
        fig.suptitle(f'Epoch_{step}')
        for i, yy in enumerate([y, y_low, y_recon]):
            if yy.dim() == 1:
                yy = self.stftmag(yy)
            ax = plt.subplot(3, 1, i + 1)
            ax.set_title(name_list[i])
            plt.imshow(rosa.amplitude_to_db(yy.numpy(),
                                            ref=np.max, top_db=80.),
                       # vmin = -20,
                       vmax=0.,
                       aspect='auto',
                       origin='lower',
                       interpolation='none')
            plt.colorbar()
            plt.xlabel('Frames')
            plt.ylabel('Channels')
            plt.tight_layout()

        fig.canvas.draw()
        data = self.fig2np(fig)

        # Close this specific figure rather than whichever one is current.
        plt.close(fig)
        return data

    @rank_zero_only
    def log_spectrogram(self, y, y_low, y_recon, epoch):
        """Log a spectrogram comparison image of the three signals at step `epoch`."""
        y, y_low, y_recon = y.detach().cpu(), y_low.detach().cpu(), y_recon.detach().cpu()
        spec_img = self.plot_spectrogram_to_numpy(y, y_low, y_recon, epoch)
        self.experiment.add_image(path.join(self.save_dir, 'result'),
                                  spec_img,
                                  epoch,
                                  dataformats='HWC')
        self.experiment.flush()

    @rank_zero_only
    def log_audio(self, y, y_low, y_recon, epoch):
        """Log the three signals as audio under the tags 'y', 'y_low' and 'y_recon'."""
        y, y_low, y_recon = y.detach().cpu(), y_low.detach().cpu(), y_recon.detach().cpu()
        name_list = ['y', 'y_low', 'y_recon']
        for n, yy in zip(name_list, [y, y_low, y_recon]):
            self.experiment.add_audio(n, yy, epoch, self.sr)
        self.experiment.flush()
utils/utils.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import librosa
4
+ import librosa.display
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
8
+
9
+ from config import CONFIG
10
+
11
+
12
def mkdir_p(mypath):
    """Creates a directory. equivalent to using mkdir -p on the command line

    Missing parent directories are created too. An already existing directory
    is not an error; an existing non-directory at `mypath` still raises OSError.
    """
    from os import makedirs

    # exist_ok=True replaces the manual "EEXIST and is a directory" check.
    makedirs(mypath, exist_ok=True)
25
+
26
+
27
def visualize(target, input, recon, path):
    """Save stacked spectrograms of the three signals to `<path>/spec.png`.

    Args:
        target: target signal, drawn in the top panel.
        input: lossy signal, drawn in the middle panel.
        recon: reconstructed signal, drawn in the bottom panel.
        path: output directory; created if it does not exist.
    """
    sr = CONFIG.DATA.sr
    window_size = 1024
    window = np.hanning(window_size)

    panels = (
        ('Target signal', target),
        ('Lossy signal', input),
        ('Reconstructed signal', recon),
    )

    fig, axes = plt.subplots(3, 1, sharey=True, sharex=True, figsize=(16, 10))
    # Attach an Agg canvas to the figure, as the original code did.
    FigureCanvas(fig)
    try:
        for ax, (title, signal) in zip(axes, panels):
            spec = librosa.stft(signal, n_fft=window_size, hop_length=512, window=window)
            # Scale the magnitude by 2 / sum(window).
            spec = 2 * np.abs(spec) / np.sum(window)
            ax.title.set_text(title)
            librosa.display.specshow(librosa.amplitude_to_db(spec), ax=ax,
                                     y_axis='linear', x_axis='time', sr=sr)
        mkdir_p(path)
        fig.savefig(os.path.join(path, 'spec.png'))
    finally:
        # Figures created through pyplot stay registered until closed; release
        # this one so repeated calls do not accumulate open figures.
        plt.close(fig)
52
+
53
+
54
def get_power(x, nfft):
    """Return the natural-log power spectrogram of `x` for an FFT size of `nfft`.

    A constant 1e-8 is added to the power before the logarithm so that silent
    bins stay finite.
    """
    magnitude = np.abs(librosa.stft(x, n_fft=nfft))
    return np.log(magnitude ** 2 + 1e-8)
58
+
59
+
60
def LSD(x_hr, x_pr):
    """Return `(lsd, lsd_high)`, log-spectral distances between two signals.

    Both signals are converted with `get_power(..., nfft=2048)`. The distance
    averages the squared difference over the last axis, takes the square root,
    then averages over the first axis. `lsd_high` repeats this on the trailing
    `(len(S) - 1) // 2` rows of each spectrogram (presumably the upper half of
    the frequency bins -- confirm against the layout `get_power` returns).
    """
    ref_spec = get_power(x_hr, nfft=2048)
    est_spec = get_power(x_pr, nfft=2048)

    def _distance(a, b):
        # 1e-8 keeps the square root away from an exact zero.
        return np.mean(np.sqrt(np.mean((a - b) ** 2 + 1e-8, axis=-1)), axis=0)

    full_band = _distance(ref_spec, est_spec)

    ref_high = ref_spec[-(len(ref_spec) - 1) // 2:, :]
    est_high = est_spec[-(len(est_spec) - 1) // 2:, :]
    high_band = _distance(ref_high, est_high)

    return full_band, high_band