Nithya committed on
Commit 97b6f36 · 1 Parent(s): a50a71e

testing feasibility

Files changed (3)
  1. app.py +268 -0
  2. src/generate_utils.py +88 -0
  3. src/pitch_to_audio_utils.py +121 -0
app.py ADDED
@@ -0,0 +1,268 @@
+ from gradio import Interface, Audio
+ import gradio as gr
+ import numpy as np
+ import torch
+ import subprocess
+ import librosa
+ import matplotlib.pyplot as plt
+ import pandas as pd
+ import os
+ from functools import partial
+ import gin
+ import sys
+ sys.path.append('./')
+ from src.generate_utils import invert_pitch_read, load_pitch_model, load_audio_model
+ import src.pitch_to_audio_utils as p2a
+ import torchaudio
+ from absl import app
+ from torch.nn.functional import interpolate
+ import pdb
+ import logging
+ import crepe
+ from hmmlearn import hmm
+ import time
+ import soundfile as sf
+
+ pitch_path = '/network/scratch/n/nithya.shikarpur/checkpoints/pitch-diffusion/corrected-attention-v3/4833583'
+ audio_path = '/network/scratch/n/nithya.shikarpur/checkpoints/pitch-diffusion/corrected-attention-v3/4835364'
+ pitch_primes = '/network/scratch/n/nithya.shikarpur/pitch-diffusion/data/merged_data-final/listening_study_primes.npz'
+ output_folder = '/network/scratch/n/nithya.shikarpur/pitch-diffusion/user-studies/listening-study-2/task-3'
+ device = 'cpu'
+
+ global_ind = -1
+ global_audios = np.array([0.0])
+ global_pitches = np.array([0])
+ singer = 3
+ audio_components = []
+ preprocessed_primes = []
+ selected_prime = None
+
+
+ def make_prime_npz(prime):
+     np.savez('./temp/prime.npz', concatenated_array=[[prime]])
+
+ def load_pitch_fns():
+     pitch_model, pitch_qt, _, pitch_task_fn = load_pitch_model(
+         os.path.join(pitch_path, 'config.gin'),
+         os.path.join(pitch_path, 'models', 'last.ckpt'),
+         os.path.join(pitch_path, 'qt.joblib'),
+         device=device
+     )
+     invert_pitch_fn = partial(
+         invert_pitch_read,
+         min_norm_pitch=gin.query_parameter('dataset.pitch_read_w_downsample.min_norm_pitch'),
+         time_downsample=gin.query_parameter('dataset.pitch_read_w_downsample.time_downsample'),
+         pitch_downsample=gin.query_parameter('dataset.pitch_read_w_downsample.pitch_downsample'),
+         qt_transform=pitch_qt,
+         min_clip=gin.query_parameter('dataset.pitch_read_w_downsample.min_clip'),
+         max_clip=gin.query_parameter('dataset.pitch_read_w_downsample.max_clip')
+     )
+     return pitch_model, pitch_qt, pitch_task_fn, invert_pitch_fn
+
+ def interpolate_pitch(pitch, audio_seq_len):
+     pitch = interpolate(pitch, size=audio_seq_len, mode='linear')
+     plt.plot(pitch[0].squeeze(0).detach().cpu().numpy())
+     plt.savefig("./temp/interpolated_pitch.png")
+     plt.close()
+     return pitch
+
+ def load_audio_fns():
+     ckpt = os.path.join(audio_path, 'models', 'checkpoint-epoch=3279-val_cross_entropy=0.00-cross_entropy=0.00.ckpt')
+     config = os.path.join(audio_path, 'config.gin')
+     qt = os.path.join(audio_path, 'qt.joblib')  # fixed: was `db_path_audio`, which is undefined in this script
+
+     audio_model, audio_qt = load_audio_model(config, ckpt, qt, device=device)
+     audio_seq_len = gin.query_parameter('%AUDIO_SEQ_LEN')
+
+     invert_audio_fn = partial(
+         p2a.normalized_mels_to_audio,
+         qt=audio_qt,
+         n_iter=200
+     )
+
+     return audio_model, audio_qt, audio_seq_len, invert_audio_fn
+
+ def predict_voicing(confidence):
+     # https://github.com/marl/crepe/pull/26
+     """
+     Find the Viterbi path for voiced versus unvoiced frames.
+     Parameters
+     ----------
+     confidence : np.ndarray [shape=(N,)]
+         Voicing confidence array, i.e. the confidence in the presence of
+         a pitch.
+     Returns
+     -------
+     voicing_states : np.ndarray [shape=(N,)]
+         HMM predictions for each frame's state: 0 if unvoiced, 1 if
+         voiced.
+     """
+     # uniform prior on the voicing confidence
+     starting = np.array([0.5, 0.5])
+
+     # transition probabilities inducing continuous voicing state
+     transition = np.array([[0.99, 0.01], [0.01, 0.99]])
+
+     # mean and variance for unvoiced and voiced states
+     means = np.array([[0.0], [1.0]])
+     variances = np.array([[0.25], [0.25]])
+
+     # fix the model parameters because we are not optimizing the model
+     model = hmm.GaussianHMM(n_components=2)
+     model.startprob_, model.covars_, model.transmat_, model.means_, \
+         model.n_features = starting, variances, transition, means, 1
+
+     # find the Viterbi path
+     voicing_states = model.predict(confidence.reshape(-1, 1), [len(confidence)])
+
+     return np.array(voicing_states)
+
+ def extract_pitch(audio, unvoice=True, sr=16000, frame_shift_ms=10, log=True):
+     time, frequency, confidence, _ = crepe.predict(
+         audio, sr=sr,
+         viterbi=True,
+         step_size=frame_shift_ms,
+         verbose=0 if not log else 1)
+     f0 = frequency
+     if unvoice:
+         is_voiced = predict_voicing(confidence)
+         frequency_unvoiced = frequency * is_voiced
+         f0 = frequency_unvoiced
+
+     return time, f0, confidence
+
+ def generate_pitch(pitch, pitch_model, invert_pitch_fn, num_samples, num_steps, outfolder=None, processed_primes=None):
+     noisy_pitch = torch.Tensor(pitch[:, :, :1200]).to(pitch_model.device) + torch.normal(mean=0.0, std=0.4 * torch.ones(1200)).to(pitch_model.device)
+     noisy_pitch = torch.clamp(noisy_pitch, -5.19, 5.19)
+     samples = pitch_model.sample_sdedit(noisy_pitch, num_samples, num_steps)
+     inverted_pitches = [invert_pitch_fn(samples.detach().cpu().numpy()[0])[0]]
+
+     if outfolder is not None:
+         os.makedirs(outfolder, exist_ok=True)
+         # pdb.set_trace()
+         for i, pitch in enumerate(inverted_pitches):
+             flattened_pitch = pitch.flatten()
+             pd.DataFrame({'f0': flattened_pitch}).to_csv(f"{outfolder}/{i}.csv", index=False)
+             plt.plot(np.where(flattened_pitch == 0, np.nan, flattened_pitch))
+             plt.savefig(f"{outfolder}/{i}.png")
+             plt.close()
+     return samples, inverted_pitches
+
+ def generate_audio(audio_model, f0s, invert_audio_fn, outfolder, singers=[3], num_steps=100):
+     singer_tensor = torch.tensor(np.repeat(singers, repeats=f0s.shape[0])).to(audio_model.device)
+     samples, _, singers = audio_model.sample_cfg(f0s.shape[0], f0=f0s, num_steps=num_steps, singer=singer_tensor, strength=3)
+     audio = invert_audio_fn(samples)
+
+     if outfolder is not None:
+         os.makedirs(outfolder, exist_ok=True)
+         for i, a in enumerate(audio):
+             logging.log(logging.INFO, f"Saving audio {i}")
+             torchaudio.save(f"{outfolder}/{i}.wav", torch.tensor(a).detach().unsqueeze(0).cpu(), 16000)
+     return audio
+
+ def generate(pitch, num_samples=2, num_steps=100, singers=[3], outfolder='temp', audio_seq_len=750, pitch_qt=None):
+     global global_ind, audio_components
+     global preprocessed_primes
+     # pdb.set_trace()
+     logging.log(logging.INFO, 'Generate function')
+     pitch, inverted_pitch = generate_pitch(pitch, pitch_model, invert_pitch_fn, 1, 100, outfolder=outfolder, processed_primes=selected_prime if global_ind != 0 else None)
+     if pitch_qt is not None:
+         def undo_qt(x, min_clip=200):
+             pitch = pitch_qt.inverse_transform(x.reshape(-1, 1)).reshape(1, -1)
+             pitch = np.around(pitch)  # round to the nearest integer, matching the preprocessing of the pitch contour fed into the model
+             pitch[pitch < min_clip] = np.nan  # fixed: use the min_clip argument instead of a hard-coded 200
+             return pitch
+         pitch = torch.tensor(np.array([undo_qt(x) for x in pitch.detach().cpu().numpy()])).to(pitch_model.device)
+     interpolated_pitch = interpolate_pitch(pitch=pitch, audio_seq_len=audio_seq_len)
+     interpolated_pitch = torch.nan_to_num(interpolated_pitch, nan=196)
+     interpolated_pitch = interpolated_pitch.squeeze(1)  # remove the extra dimension to match the audio model's input size
+     audio = generate_audio(audio_model, interpolated_pitch, invert_audio_fn, singers=singers, num_steps=100, outfolder=outfolder)
+     # pdb.set_trace()
+     audio = audio.detach().cpu().numpy()[:, :]
+     pitch = pitch.detach().cpu().numpy()
+     # state = [(16000, audio[0]), (16000, audio[1])]
+     # pdb.set_trace()
+     pitch_vals = np.where(pitch[0][:, 0] == 0, np.nan, pitch[0].flatten())
+     fig1 = plt.figure()
+     # plt.plot(np.arange(0, 400), pitch_vals[:400], figure=fig1, label='User Input')
+     plt.plot(pitch_vals, figure=fig1, label='Pitch')
+     # plt.legend(fig1)
+     # state.append(fig1)
+     plt.close(fig1)
+     return (16000, audio[0]), fig1, pitch_vals
+
+ pitch_model, pitch_qt, pitch_task_fn, invert_pitch_fn = load_pitch_fns()
+ audio_model, audio_qt, audio_seq_len, invert_audio_fn = load_audio_fns()
+ partial_generate = partial(generate, num_samples=1, num_steps=100, singers=[3], outfolder='temp', pitch_qt=pitch_qt)
+
+ def set_prime_and_generate(audio, full_pitch, full_audio, full_user):
+     global selected_prime, pitch_task_fn
+
+     if audio is None:
+         return None, None, full_pitch, full_audio, full_user, None  # fixed: the click handler expects six outputs
+     sr, audio = audio
+     if len(audio) < 12*sr:
+         audio = np.pad(audio, (0, 12*sr - len(audio)), mode='constant')
+
+     audio = audio.astype(np.float32)
+     audio /= np.max(np.abs(audio))
+     audio = librosa.resample(audio, orig_sr=sr, target_sr=16000)  # resample to 16 kHz
+     mic_audio = audio.copy()
+     audio = audio[-12*16000:]  # keep only the last 12 s
+     _, f0, _ = extract_pitch(audio)
+     mic_f0 = f0.copy()
+     f0 = pitch_task_fn({
+         'pitch': {
+             'data': f0,
+             'sampling_rate': 100
+         }
+     }, qt_transform=pitch_qt)
+     f0 = f0.reshape(1, 1, -1)
+     f0 = torch.tensor(f0).to(pitch_model.device).float()
+     audio, pitch, pitch_vals = partial_generate(f0)
+     # pdb.set_trace()
+     full_pitch = np.concatenate((full_pitch, mic_f0, pitch_vals))
+     full_user = np.concatenate((full_user, ['User'] * len(mic_f0), ['Model'] * len(pitch_vals)))
+     full_audio = (full_audio[0], np.concatenate((full_audio[1], mic_audio, audio[1])))  # fixed: tuples are immutable, so rebuild the state instead of assigning to full_audio[1]
+     # pdb.set_trace()
+     fig = plt.figure()
+     plt.plot(np.arange(0, len(mic_f0)), mic_f0, label='User Input', figure=fig)
+     plt.close(fig)
+     return audio, pitch, full_pitch, full_audio, full_user, fig
+
+ def save_session(full_pitch, full_audio, full_user):
+     os.makedirs(output_folder, exist_ok=True)
+     filename = f'session-{time.time()}'
+     logging.log(logging.INFO, f"Saving session to {filename}")
+     pd.DataFrame({'pitch': full_pitch, 'time': np.arange(0, len(full_pitch)/100, 0.01), 'user': full_user}).to_csv(os.path.join(output_folder, filename + '.csv'), index=False)
+     sf.write(os.path.join(output_folder, filename + '.wav'), full_audio[1], 16000)
+
+ with gr.Blocks() as demo:
+     full_audio = gr.State((16000, np.array([])))
+     full_pitch = gr.State(np.array([]))
+     full_user = gr.State(np.array([]))
+     with gr.Row():
+         with gr.Column():
+             audio = gr.Audio(label="Input")
+             sbmt = gr.Button()
+             user_input = gr.Plot(label="User Input")
+         with gr.Column():
+             generated_audio = gr.Audio(label="Generated Audio")
+             generated_pitch = gr.Plot(label="Generated Pitch")
+     sbmt.click(set_prime_and_generate, inputs=[audio, full_pitch, full_audio, full_user], outputs=[generated_audio, generated_pitch, full_pitch, full_audio, full_user, user_input])
+     save = gr.Button("Save Session")
+     save.click(save_session, inputs=[full_pitch, full_audio, full_user])
+
+
+
+ def main(argv):
+     # audio = np.random.randint(0, high=128, size=(44100*5), dtype=np.int16)
+     # sr = 44100
+     # pdb.set_trace()
+     # p, a = set_prime_and_generate((sr, audio))
+
+     demo.launch(share=True)
+
+ if __name__ == '__main__':
+     app.run(main)
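
Note: app.py extracts the sung melody with CREPE and then silences unvoiced frames by decoding a two-state Gaussian HMM over the voicing confidence (the approach referenced from crepe PR #26 in predict_voicing). The sketch below exercises that same pipeline standalone; it is not part of the commit, the 220 Hz test tone is made up for illustration, and it assumes crepe and hmmlearn are installed.

# Minimal sketch (not from the commit): the CREPE + HMM voicing step from
# extract_pitch()/predict_voicing(), run standalone on a made-up 220 Hz tone.
import numpy as np
import crepe
from hmmlearn import hmm

sr = 16000
t = np.arange(0, 2.0, 1.0 / sr)
audio = 0.5 * np.sin(2 * np.pi * 220 * t).astype(np.float32)  # hypothetical input signal

# CREPE returns frame times, f0 estimates in Hz, a voicing confidence, and the activation map.
_, frequency, confidence, _ = crepe.predict(audio, sr=sr, viterbi=True, step_size=10, verbose=0)

# Same fixed two-state GaussianHMM as predict_voicing(): decode voiced/unvoiced
# states from the confidence curve, then zero out the unvoiced frames.
model = hmm.GaussianHMM(n_components=2)
model.startprob_ = np.array([0.5, 0.5])
model.transmat_ = np.array([[0.99, 0.01], [0.01, 0.99]])
model.means_ = np.array([[0.0], [1.0]])
model.covars_ = np.array([[0.25], [0.25]])
model.n_features = 1
is_voiced = model.predict(confidence.reshape(-1, 1), [len(confidence)])
f0 = frequency * is_voiced  # 0 Hz marks unvoiced frames, as in extract_pitch()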
src/generate_utils.py ADDED
@@ -0,0 +1,88 @@
+ import numpy as np
+ from typing import Optional
+ from sklearn.preprocessing import QuantileTransformer
+ import sys
+ import pdb
+ sys.path.append('../pitch-diffusion')
+ import torch
+ import gin
+ from src.model import UNet, UNetPitchConditioned
+ from functools import partial
+ import joblib
+ from src.dataset import hz_to_cents, pitch_read_w_downsample
+
+ def invert_pitch_read(pitch,
+                       min_norm_pitch: int,
+                       time_downsample: int,
+                       pitch_downsample: int,
+                       qt_transform: Optional[QuantileTransformer],
+                       min_clip: int,
+                       max_clip: int):
+     try:
+         pitch = pitch.detach().cpu().numpy()
+     except AttributeError:  # already a numpy array
+         pass
+     if qt_transform is not None:
+         pitch = qt_transform.inverse_transform(pitch.reshape(-1, 1))
+         pitch = pitch.reshape(1, -1)  # fixed: the reshape was not assigned back
+     pitch[pitch < min_clip] = np.nan
+     pitch[~np.isnan(pitch)] = (pitch[~np.isnan(pitch)] - 1) * pitch_downsample
+     pitch[~np.isnan(pitch)] = pitch[~np.isnan(pitch)] + min_norm_pitch
+     pitch[~np.isnan(pitch)] = 440 * 2**(pitch[~np.isnan(pitch)] / 1200)
+     pitch[np.isnan(pitch)] = 0
+
+     return pitch, 200 // time_downsample
+
+ def invert_tonic(tonic: Optional[int] = None,
+                  min_norm_pitch: int = 0,
+                  min_clip: int = 200,
+                  pitch_downsample: int = 1,
+                  ):
+     tonic += min_clip
+     tonic = pitch_downsample * (tonic - 1)
+     tonic += min_norm_pitch
+     tonic = 440 * 2**(tonic / 1200)
+
+     return tonic
+
+ def load_processed_pitch(pitch,
+                          audio_seq_len: int,
+                          min_norm_pitch: int,
+                          pitch_downsample: int,
+                          min_clip: int,
+                          max_clip: int,
+                          ):
+     # pdb.set_trace()
+     pitch = hz_to_cents(pitch, min_norm_pitch=min_norm_pitch, min_clip=min_clip, max_clip=max_clip, pitch_downsample=pitch_downsample, silence_token=min_clip-4)
+     pitch_inds = np.linspace(0, pitch.shape[0], num=audio_seq_len, endpoint=False)
+     pitch = np.interp(pitch_inds, np.arange(0, pitch.shape[0]), pitch)
+     return pitch
+
+ def load_pitch_model(config, ckpt, qt=None, prime_file=None, device='cuda'):
+     gin.parse_config_file(config)
+     model = UNet()
+     model.load_state_dict(torch.load(ckpt)['state_dict'])
+     model.to(device)
+     if qt is not None:
+         qt = joblib.load(qt)
+     if prime_file is not None:
+         with gin.config_scope('val'):  # probably have to change this
+             with gin.unlock_config():
+                 gin.bind_parameter('dataset.pitch_read_w_downsample.qt_transform', qt)
+         primes = np.load(prime_file, allow_pickle=True)['concatenated_array'][:, 0]
+     else:
+         primes = None
+     task_fn = None
+     task_fn = partial(pitch_read_w_downsample,
+                       seq_len=None)
+     return model, qt, primes, task_fn
+
+ def load_audio_model(config, ckpt, qt=None, device='cuda'):
+     gin.parse_config_file(config)
+     model = UNetPitchConditioned()  # there are no gin parameters for some reason
+     model.load_state_dict(torch.load(ckpt)['state_dict'])
+     model.to(device)
+     if qt is not None:
+         qt = joblib.load(qt)
+
+     return model, qt
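
Note: invert_pitch_read undoes the dataset normalization in three steps: invert the quantile transform, map the normalized value back to cents via pitch_downsample and min_norm_pitch, and convert cents to Hz against a 440 Hz reference; anything below min_clip becomes silence and is written out as 0 Hz. The small numeric sketch below mirrors that arithmetic only; the parameter values are hypothetical stand-ins for what config.gin provides.

# Sketch of the de-normalization arithmetic in invert_pitch_read() (not from the commit).
# The parameter values here are hypothetical; the real ones are read from config.gin.
import numpy as np

min_norm_pitch = -3000   # assumed value of dataset.pitch_read_w_downsample.min_norm_pitch
pitch_downsample = 10    # assumed value of dataset.pitch_read_w_downsample.pitch_downsample
min_clip = 200           # values below this are treated as silence

bins = np.array([150.0, 420.0, 433.0])                  # toy values after the quantile inverse transform
bins[bins < min_clip] = np.nan                          # 150 falls below min_clip -> silence

cents = (bins - 1) * pitch_downsample + min_norm_pitch  # back to cents relative to A440
hz = 440.0 * 2 ** (cents / 1200.0)                      # cents -> Hz
hz = np.nan_to_num(hz, nan=0.0)                         # 0 Hz marks unvoiced/silent frames
print(hz)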
src/pitch_to_audio_utils.py ADDED
@@ -0,0 +1,121 @@
+ import math
+ import librosa as li
+ import torch
+ from tqdm import tqdm
+ import numpy as np
+ import gin
+ import logging
+
+ import pdb
+
+ @gin.configurable
+ def torch_stft(x, nfft):
+     window = torch.hann_window(nfft).to(x)
+     x = torch.stft(
+         x,
+         n_fft=nfft,
+         hop_length=nfft // 4,
+         win_length=nfft,
+         window=window,
+         center=True,
+         return_complex=True,
+     )
+     x = 2 * x / torch.mean(window)
+     return x
+
+ @gin.configurable
+ def torch_istft(x, nfft):
+     # pdb.set_trace()
+     window = torch.hann_window(nfft).to(x.device)
+     x = x / 2 * torch.mean(window)
+     return torch.istft(
+         x,
+         n_fft=nfft,
+         hop_length=nfft // 4,
+         win_length=nfft,
+         window=window,
+         center=True,
+     )
+
+ @gin.configurable
+ def to_mels(stft, nfft, num_mels, sr, eps=1e-2):
+     mels = li.filters.mel(
+         sr=sr,
+         n_fft=nfft,
+         n_mels=num_mels,
+         fmin=40,
+     )
+     # pdb.set_trace()
+     mels = torch.from_numpy(mels).to(stft)
+     mel_stft = torch.einsum("mf,bft->bmt", mels, stft)
+     mel_stft = torch.log(mel_stft + eps)
+     return mel_stft
+
+ @gin.configurable
+ def from_mels(mel_stft, nfft, num_mels, sr, eps=1e-2):
+     mels = li.filters.mel(
+         sr=sr,
+         n_fft=nfft,
+         n_mels=num_mels,
+         fmin=40,
+     )
+     mels = torch.from_numpy(mels).to(mel_stft)
+     mels = torch.pinverse(mels)
+     mel_stft = torch.exp(mel_stft) - eps
+     stft = torch.einsum("fm,bmt->bft", mels, mel_stft)
+     return stft
+
+ @gin.configurable
+ def torch_gl(stft, nfft, sr, n_iter):
+
+     def _gl_iter(phase, xs, stft):
+         del xs
+         # pdb.set_trace()
+         c_stft = stft * torch.exp(1j * phase)
+         rec = torch_istft(c_stft, nfft)
+         r_stft = torch_stft(rec, nfft)
+         phase = torch.angle(r_stft)
+         return phase, None
+
+     phase = torch.rand_like(stft) * 2 * torch.pi
+
+     for _ in tqdm(range(n_iter)):
+         phase, _ = _gl_iter(phase, None, stft)
+
+     c_stft = stft * torch.exp(1j * phase)
+     audio = torch_istft(c_stft, nfft)
+
+     return audio
+
+ @gin.configurable
+ def normalize(x, qt=None):
+     x_flat = x.reshape(-1, 1)
+     if qt is None:
+         logging.warning('No quantile transformer found, returning input')
+         return x
+     return torch.Tensor(qt.transform(x_flat).reshape(x.shape))
+
+ @gin.configurable
+ def unnormalize(x, qt=None):
+     x_flat = x.reshape(-1, 1)
+     if qt is None:
+         logging.warning('No quantile transformer found, returning input')
+         return x
+     if isinstance(x_flat, torch.Tensor):
+         x_flat = x_flat.detach().cpu().numpy()
+     return torch.Tensor(qt.inverse_transform(x_flat).reshape(x.shape))
+
+ @gin.configurable
+ def audio_to_normalized_mels(x, nfft, num_mels, sr, qt):
+     # pdb.set_trace()
+     stfts = torch_stft(x, nfft=nfft).abs()[..., :-1]
+     mel_stfts = to_mels(stfts, nfft, num_mels, sr)
+     return normalize(mel_stfts, qt).to(x)
+
+ @gin.configurable
+ def normalized_mels_to_audio(x, nfft, num_mels, sr, qt, n_iter=20):
+     x = unnormalize(x, qt).to(x)
+     x = from_mels(x, nfft, num_mels, sr)
+     x = torch.clamp(x, 0, nfft)
+     x = torch_gl(x, nfft, sr, n_iter=n_iter)
+     return x
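
Note: this module bridges waveforms and the normalized log-mel representation the audio model works in: audio_to_normalized_mels goes STFT -> mel -> log -> quantile-normalize, and normalized_mels_to_audio inverts each step, ending with Griffin-Lim phase reconstruction in torch_gl. A rough round-trip sketch follows; the nfft/num_mels/sr values are hypothetical (in the app they come from the gin config), qt is left as None so normalize/unnormalize simply pass values through with a warning, and it assumes the repo root is on sys.path.

# Round-trip sketch (not from the commit); hypothetical STFT/mel settings.
import torch
import src.pitch_to_audio_utils as p2a

sr, nfft, num_mels = 16000, 1024, 192       # assumed values; the real ones live in config.gin
audio = torch.randn(1, 2 * sr)              # 2 s of noise as a stand-in input signal

# forward: waveform -> (optionally quantile-normalized) log-mel spectrogram
mels = p2a.audio_to_normalized_mels(audio, nfft=nfft, num_mels=num_mels, sr=sr, qt=None)

# inverse: log-mel -> linear STFT magnitude (mel pseudo-inverse) -> Griffin-Lim phase -> waveform
recon = p2a.normalized_mels_to_audio(mels, nfft=nfft, num_mels=num_mels, sr=sr, qt=None, n_iter=20)
print(mels.shape, recon.shape)              # (1, num_mels, frames) and (1, samples)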