Nithya committed · 97b6f36
Parent(s): a50a71e
testing feasibility
- app.py +268 -0
- src/generate_utils.py +88 -0
- src/pitch_to_audio_utils.py +121 -0
app.py
ADDED
@@ -0,0 +1,268 @@
import gradio as gr
import numpy as np
import torch
import librosa
import matplotlib.pyplot as plt
import pandas as pd
import os
from functools import partial
import gin
import sys
sys.path.append('./')
from src.generate_utils import invert_pitch_read, load_pitch_model, load_audio_model
import src.pitch_to_audio_utils as p2a
import torchaudio
from absl import app
from torch.nn.functional import interpolate
import logging
import crepe
from hmmlearn import hmm
import time
import soundfile as sf

# checkpoint and output locations (cluster-specific paths)
pitch_path = '/network/scratch/n/nithya.shikarpur/checkpoints/pitch-diffusion/corrected-attention-v3/4833583'
audio_path = '/network/scratch/n/nithya.shikarpur/checkpoints/pitch-diffusion/corrected-attention-v3/4835364'
pitch_primes = '/network/scratch/n/nithya.shikarpur/pitch-diffusion/data/merged_data-final/listening_study_primes.npz'
output_folder = '/network/scratch/n/nithya.shikarpur/pitch-diffusion/user-studies/listening-study-2/task-3'
device = 'cpu'

global_ind = -1
global_audios = np.array([0.0])
global_pitches = np.array([0])
singer = 3
audio_components = []
preprocessed_primes = []
selected_prime = None


def make_prime_npz(prime):
    np.savez('./temp/prime.npz', concatenated_array=[[prime]])


def load_pitch_fns():
    # load the pitch diffusion model plus the quantile transform and the
    # function that inverts the dataset's pitch preprocessing
    pitch_model, pitch_qt, _, pitch_task_fn = load_pitch_model(
        os.path.join(pitch_path, 'config.gin'),
        os.path.join(pitch_path, 'models', 'last.ckpt'),
        os.path.join(pitch_path, 'qt.joblib'),
        device=device
    )
    invert_pitch_fn = partial(
        invert_pitch_read,
        min_norm_pitch=gin.query_parameter('dataset.pitch_read_w_downsample.min_norm_pitch'),
        time_downsample=gin.query_parameter('dataset.pitch_read_w_downsample.time_downsample'),
        pitch_downsample=gin.query_parameter('dataset.pitch_read_w_downsample.pitch_downsample'),
        qt_transform=pitch_qt,
        min_clip=gin.query_parameter('dataset.pitch_read_w_downsample.min_clip'),
        max_clip=gin.query_parameter('dataset.pitch_read_w_downsample.max_clip')
    )
    return pitch_model, pitch_qt, pitch_task_fn, invert_pitch_fn


def interpolate_pitch(pitch, audio_seq_len):
    # resample the pitch contour to the audio model's sequence length
    pitch = interpolate(pitch, size=audio_seq_len, mode='linear')
    plt.plot(pitch[0].squeeze(0).detach().cpu().numpy())
    plt.savefig("./temp/interpolated_pitch.png")
    plt.close()
    return pitch


def load_audio_fns():
    ckpt = os.path.join(audio_path, 'models', 'checkpoint-epoch=3279-val_cross_entropy=0.00-cross_entropy=0.00.ckpt')
    config = os.path.join(audio_path, 'config.gin')
    qt = os.path.join(audio_path, 'qt.joblib')

    audio_model, audio_qt = load_audio_model(config, ckpt, qt, device=device)
    audio_seq_len = gin.query_parameter('%AUDIO_SEQ_LEN')

    invert_audio_fn = partial(
        p2a.normalized_mels_to_audio,
        qt=audio_qt,
        n_iter=200
    )

    return audio_model, audio_qt, audio_seq_len, invert_audio_fn


def predict_voicing(confidence):
    # https://github.com/marl/crepe/pull/26
    """
    Find the Viterbi path for voiced versus unvoiced frames.

    Parameters
    ----------
    confidence : np.ndarray [shape=(N,)]
        voicing confidence array, i.e. the confidence in the presence of
        a pitch

    Returns
    -------
    voicing_states : np.ndarray [shape=(N,)]
        HMM predictions for each frame's state, 0 if unvoiced, 1 if
        voiced
    """
    # uniform prior on the voicing confidence
    starting = np.array([0.5, 0.5])

    # transition probabilities inducing continuous voicing state
    transition = np.array([[0.99, 0.01], [0.01, 0.99]])

    # mean and variance for unvoiced and voiced states
    means = np.array([[0.0], [1.0]])
    variances = np.array([[0.25], [0.25]])

    # fix the model parameters because we are not optimizing the model
    model = hmm.GaussianHMM(n_components=2)
    model.startprob_, model.covars_, model.transmat_, model.means_, \
        model.n_features = starting, variances, transition, means, 1

    # find the Viterbi path
    voicing_states = model.predict(confidence.reshape(-1, 1), [len(confidence)])

    return np.array(voicing_states)
+
|
121 |
+
def extract_pitch(audio, unvoice=True, sr=16000, frame_shift_ms=10, log=True):
|
122 |
+
time, frequency, confidence, _ = crepe.predict(
|
123 |
+
audio, sr=sr,
|
124 |
+
viterbi=True,
|
125 |
+
step_size=frame_shift_ms,
|
126 |
+
verbose=0 if not log else 1)
|
127 |
+
f0 = frequency
|
128 |
+
if unvoice:
|
129 |
+
is_voiced = predict_voicing(confidence)
|
130 |
+
frequency_unvoiced = frequency * is_voiced
|
131 |
+
f0 = frequency_unvoiced
|
132 |
+
|
133 |
+
return time, f0, confidence
|
134 |
+
|
135 |
+
def generate_pitch(pitch, pitch_model, invert_pitch_fn, num_samples, num_steps, outfolder=None, processed_primes=None):
|
136 |
+
noisy_pitch = torch.Tensor(pitch[:, :, :1200]).to(pitch_model.device) + (torch.normal(mean=0.0, std=0.4*torch.ones(( 1200)))).to(pitch_model.device)
|
137 |
+
noisy_pitch = torch.clamp(noisy_pitch, -5.19, 5.19)
|
138 |
+
samples = pitch_model.sample_sdedit(noisy_pitch, num_samples, num_steps)
|
139 |
+
inverted_pitches = [invert_pitch_fn(samples.detach().cpu().numpy()[0])[0]]
|
140 |
+
|
141 |
+
if outfolder is not None:
|
142 |
+
os.makedirs(outfolder, exist_ok=True)
|
143 |
+
# pdb.set_trace()
|
144 |
+
for i, pitch in enumerate(inverted_pitches):
|
145 |
+
flattened_pitch = pitch.flatten()
|
146 |
+
pd.DataFrame({'f0': flattened_pitch}).to_csv(f"{outfolder}/{i}.csv", index=False)
|
147 |
+
plt.plot(np.where(flattened_pitch == 0, np.nan, flattened_pitch))
|
148 |
+
plt.savefig(f"{outfolder}/{i}.png")
|
149 |
+
plt.close()
|
150 |
+
return samples, inverted_pitches
|
151 |
+
|
152 |
+
def generate_audio(audio_model, f0s, invert_audio_fn, outfolder, singers=[3], num_steps=100):
|
153 |
+
singer_tensor = torch.tensor(np.repeat(singers, repeats=f0s.shape[0])).to(audio_model.device)
|
154 |
+
samples, _, singers = audio_model.sample_cfg(f0s.shape[0], f0=f0s, num_steps=num_steps, singer=singer_tensor, strength=3)
|
155 |
+
audio = invert_audio_fn(samples)
|
156 |
+
|
157 |
+
if outfolder is not None:
|
158 |
+
os.makedirs(outfolder, exist_ok=True)
|
159 |
+
for i, a in enumerate(audio):
|
160 |
+
logging.log(logging.INFO, f"Saving audio {i}")
|
161 |
+
torchaudio.save(f"{outfolder}/{i}.wav", torch.tensor(a).detach().unsqueeze(0).cpu(), 16000)
|
162 |
+
return audio
|
163 |
+
|
164 |
+
def generate(pitch, num_samples=2, num_steps=100, singers=[3], outfolder='temp', audio_seq_len=750, pitch_qt=None ):
|
165 |
+
global global_ind, audio_components
|
166 |
+
global preprocessed_primes
|
167 |
+
# pdb.set_trace()
|
168 |
+
logging.log(logging.INFO, 'Generate function')
|
169 |
+
pitch, inverted_pitch = generate_pitch(pitch, pitch_model, invert_pitch_fn, 1, 100, outfolder=outfolder, processed_primes=selected_prime if global_ind != 0 else None)
|
170 |
+
if pitch_qt is not None:
|
171 |
+
def undo_qt(x, min_clip=200):
|
172 |
+
pitch= pitch_qt.inverse_transform(x.reshape(-1, 1)).reshape(1, -1)
|
173 |
+
pitch = np.around(pitch) # round to nearest integer, done in preprocessing of pitch contour fed into model
|
174 |
+
pitch[pitch < 200] = np.nan
|
175 |
+
return pitch
|
176 |
+
pitch = torch.tensor(np.array([undo_qt(x) for x in pitch.detach().cpu().numpy()])).to(pitch_model.device)
|
177 |
+
interpolated_pitch = interpolate_pitch(pitch=pitch, audio_seq_len=audio_seq_len)
|
178 |
+
interpolated_pitch = torch.nan_to_num(interpolated_pitch, nan=196)
|
179 |
+
interpolated_pitch = interpolated_pitch.squeeze(1) # to match input size by removing the extra dimension
|
180 |
+
audio = generate_audio(audio_model, interpolated_pitch, invert_audio_fn, singers=singers, num_steps=100, outfolder=outfolder)
|
181 |
+
# pdb.set_trace()
|
182 |
+
audio = audio.detach().cpu().numpy()[:, :]
|
183 |
+
pitch = pitch.detach().cpu().numpy()
|
184 |
+
# state = [(16000, audio[0]), (16000, audio[1])]
|
185 |
+
# pdb.set_trace()
|
186 |
+
pitch_vals = np.where(pitch[0][:, 0] == 0, np.nan, pitch[0].flatten())
|
187 |
+
fig1 = plt.figure()
|
188 |
+
# plt.plot(np.arange(0, 400), pitch_vals[:400], figure=fig1, label='User Input')
|
189 |
+
plt.plot(pitch_vals, figure=fig1, label='Pitch')
|
190 |
+
# plt.legend(fig1)
|
191 |
+
# state.append(fig1)
|
192 |
+
plt.close(fig1)
|
193 |
+
return (16000, audio[0]), fig1, pitch_vals
|
194 |
+
|
195 |
+
pitch_model, pitch_qt, pitch_task_fn, invert_pitch_fn = load_pitch_fns()
|
196 |
+
audio_model, audio_qt, audio_seq_len, invert_audio_fn = load_audio_fns()
|
197 |
+
partial_generate = partial(generate, num_samples=1, num_steps=100, singers=[3], outfolder='temp', pitch_qt=pitch_qt)
|
198 |
+
|
199 |
+
def set_prime_and_generate(audio, full_pitch, full_audio, full_user):
|
200 |
+
global selected_prime, pitch_task_fn
|
201 |
+
|
202 |
+
if audio is None:
|
203 |
+
return None, None
|
204 |
+
sr, audio = audio
|
205 |
+
if len(audio) < 12*sr:
|
206 |
+
audio = np.pad(audio, (0, 12*sr - len(audio)), mode='constant')
|
207 |
+
|
208 |
+
audio = audio.astype(np.float32)
|
209 |
+
audio /= np.max(np.abs(audio))
|
210 |
+
audio = librosa.resample(audio, orig_sr=sr, target_sr=16000) # convert only last 4 s
|
211 |
+
mic_audio = audio.copy()
|
212 |
+
audio = audio[-12*16000:]
|
213 |
+
_, f0, _ = extract_pitch(audio)
|
214 |
+
mic_f0 = f0.copy()
|
215 |
+
f0 = pitch_task_fn({
|
216 |
+
'pitch': {
|
217 |
+
'data': f0,
|
218 |
+
'sampling_rate': 100
|
219 |
+
}
|
220 |
+
}, qt_transform=pitch_qt)
|
221 |
+
f0 = f0.reshape(1, 1, -1)
|
222 |
+
f0 = torch.tensor(f0).to(pitch_model.device).float()
|
223 |
+
audio, pitch, pitch_vals = partial_generate(f0)
|
224 |
+
# pdb.set_trace()
|
225 |
+
full_pitch = np.concatenate((full_pitch, mic_f0, pitch_vals))
|
226 |
+
full_user = np.concatenate((full_user, ['User'] * len(mic_f0), ['Model'] * len(pitch_vals)))
|
227 |
+
full_audio[1] = np.concatenate((full_audio[1], mic_audio, audio[1]))
|
228 |
+
# pdb.set_trace()
|
229 |
+
fig = plt.figure()
|
230 |
+
plt.plot(np.arange(0, len(mic_f0)), mic_f0, label='User Input', figure=fig)
|
231 |
+
plt.close(fig)
|
232 |
+
return audio, pitch, full_pitch, full_audio, full_user, fig
|
233 |
+
|
234 |
+
def save_session(full_pitch, full_audio, full_user):
|
235 |
+
os.makedirs(output_folder, exist_ok=True)
|
236 |
+
filename = f'session-{time.time()}'
|
237 |
+
logging.log(logging.INFO, f"Saving session to {filename}")
|
238 |
+
pd.DataFrame({'pitch': full_pitch, 'time': np.arange(0, len(full_pitch)/100, 0.01), 'user': full_user}).to_csv(os.path.join(output_folder, filename + '.csv'), index=False)
|
239 |
+
sf.write(os.path.join(output_folder, filename + '.wav'), full_audio[1], 16000)
|
240 |
+
|
241 |
+
with gr.Blocks() as demo:
|
242 |
+
full_audio = gr.State((16000, np.array([])))
|
243 |
+
full_pitch = gr.State(np.array([]))
|
244 |
+
full_user = gr.State(np.array([]))
|
245 |
+
with gr.Row():
|
246 |
+
with gr.Column():
|
247 |
+
audio = gr.Audio(label="Input")
|
248 |
+
sbmt = gr.Button()
|
249 |
+
user_input = gr.Plot(label="User Input")
|
250 |
+
with gr.Column():
|
251 |
+
generated_audio = gr.Audio(label="Generated Audio")
|
252 |
+
generated_pitch = gr.Plot(label="Generated Pitch")
|
253 |
+
sbmt.click(set_prime_and_generate, inputs=[audio, full_pitch, full_audio, full_user], outputs=[generated_audio, generated_pitch, full_pitch, full_audio, full_user, user_input])
|
254 |
+
save = gr.Button("Save Session")
|
255 |
+
save.click(save_session, inputs=[full_pitch, full_audio, full_user])
|
256 |
+
|
257 |
+
|
258 |
+
|
259 |
+
def main(argv):
|
260 |
+
# audio = np.random.randint(0, high=128, size=(44100*5), dtype=np.int16)
|
261 |
+
# sr = 44100
|
262 |
+
# pdb.set_trace()
|
263 |
+
# p, a = set_prime_and_generate((sr, audio))
|
264 |
+
|
265 |
+
demo.launch(share=True)
|
266 |
+
|
267 |
+
if __name__ == '__main__':
|
268 |
+
app.run(main)
|
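
Taken together, `set_prime_and_generate` implements the app's record-and-continue loop: CREPE extracts the singer's f0 and an HMM voicing mask, the pitch diffusion model continues the contour via the noisy SDEdit-style resampling in `generate_pitch`, and the pitch-conditioned audio model plus Griffin-Lim renders the waveform. A minimal sketch of driving the same pipeline without the Gradio UI, reusing the functions defined above ('prime.wav' and 'continuation.wav' are hypothetical file names, not part of the commit):

# Hedged sketch: run the record -> continue pipeline from a script.
import numpy as np
import librosa
import soundfile as sf
import torch

wav, in_sr = sf.read('prime.wav')            # hypothetical input recording
wav = librosa.resample(wav.astype(np.float32), orig_sr=in_sr, target_sr=16000)
if len(wav) < 12 * 16000:
    wav = np.pad(wav, (0, 12 * 16000 - len(wav)), mode='constant')
wav = wav[-12 * 16000:]                      # the model is primed on the last 12 s

_, f0, _ = extract_pitch(wav)                # CREPE f0 plus HMM voicing mask
f0 = pitch_task_fn({'pitch': {'data': f0, 'sampling_rate': 100}}, qt_transform=pitch_qt)
f0 = torch.tensor(f0.reshape(1, 1, -1)).float().to(pitch_model.device)

(out_sr, out_wav), _, _ = partial_generate(f0)   # diffusion + Griffin-Lim vocoding
sf.write('continuation.wav', out_wav, out_sr)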
src/generate_utils.py
ADDED
@@ -0,0 +1,88 @@
import numpy as np
from typing import Optional
from sklearn.preprocessing import QuantileTransformer
import sys
sys.path.append('../pitch-diffusion')
import torch
import gin
from src.model import UNet, UNetPitchConditioned
from functools import partial
import joblib
from src.dataset import hz_to_cents, pitch_read_w_downsample


def invert_pitch_read(pitch,
                      min_norm_pitch: int,
                      time_downsample: int,
                      pitch_downsample: int,
                      qt_transform: Optional[QuantileTransformer],
                      min_clip: int,
                      max_clip: int):
    # invert the dataset's pitch preprocessing: quantile transform,
    # downsample/offset, then cents back to Hz
    if isinstance(pitch, torch.Tensor):
        pitch = pitch.detach().cpu().numpy()
    if qt_transform is not None:
        pitch = qt_transform.inverse_transform(pitch.reshape(-1, 1))
        pitch = pitch.reshape(1, -1)
    pitch[pitch < min_clip] = np.nan  # values below min_clip encode silence
    pitch[~np.isnan(pitch)] = (pitch[~np.isnan(pitch)] - 1) * pitch_downsample
    pitch[~np.isnan(pitch)] = pitch[~np.isnan(pitch)] + min_norm_pitch
    pitch[~np.isnan(pitch)] = 440 * 2**(pitch[~np.isnan(pitch)] / 1200)  # cents -> Hz
    pitch[np.isnan(pitch)] = 0

    return pitch, 200 // time_downsample


def invert_tonic(tonic: Optional[int] = None,
                 min_norm_pitch: int = 0,
                 min_clip: int = 200,
                 pitch_downsample: int = 1,
                 ):
    tonic += min_clip
    tonic = pitch_downsample * (tonic - 1)
    tonic += min_norm_pitch
    tonic = 440 * 2**(tonic / 1200)

    return tonic


def load_processed_pitch(pitch,
                         audio_seq_len: int,
                         min_norm_pitch: int,
                         pitch_downsample: int,
                         min_clip: int,
                         max_clip: int,
                         ):
    pitch = hz_to_cents(pitch, min_norm_pitch=min_norm_pitch, min_clip=min_clip, max_clip=max_clip, pitch_downsample=pitch_downsample, silence_token=min_clip - 4)
    # resample the contour to the audio model's sequence length
    pitch_inds = np.linspace(0, pitch.shape[0], num=audio_seq_len, endpoint=False)
    pitch = np.interp(pitch_inds, np.arange(0, pitch.shape[0]), pitch)
    return pitch


def load_pitch_model(config, ckpt, qt=None, prime_file=None, device='cuda'):
    gin.parse_config_file(config)
    model = UNet()
    model.load_state_dict(torch.load(ckpt)['state_dict'])
    model.to(device)
    if qt is not None:
        qt = joblib.load(qt)
    if prime_file is not None:
        with gin.config_scope('val'):  # probably have to change this
            with gin.unlock_config():
                gin.bind_parameter('dataset.pitch_read_w_downsample.qt_transform', qt)
        primes = np.load(prime_file, allow_pickle=True)['concatenated_array'][:, 0]
    else:
        primes = None
    task_fn = partial(pitch_read_w_downsample, seq_len=None)
    return model, qt, primes, task_fn


def load_audio_model(config, ckpt, qt=None, device='cuda'):
    gin.parse_config_file(config)
    model = UNetPitchConditioned()  # there are no gin parameters for some reason
    model.load_state_dict(torch.load(ckpt)['state_dict'])
    model.to(device)
    if qt is not None:
        qt = joblib.load(qt)

    return model, qt
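
`invert_pitch_read` undoes the dataset's pitch encoding step by step: the quantile transform, then the downsample and offset, and finally the cents-to-Hz map f = 440 * 2**(cents / 1200) around the A4 = 440 Hz reference. A small self-contained check of that last round trip; `cents_re_a4` below is a sketch of the forward relation only, not the actual `src.dataset.hz_to_cents`, which also handles clipping and silence tokens:

import numpy as np

def cents_re_a4(f_hz):
    # sketch of the cents encoding that 440 * 2**(c / 1200) inverts;
    # not the real src.dataset.hz_to_cents, which also clips and tokenizes silence
    return 1200 * np.log2(f_hz / 440.0)

f = np.array([220.0, 440.0, 880.0])           # A3, A4, A5
c = cents_re_a4(f)                            # [-1200, 0, 1200] cents
assert np.allclose(440 * 2 ** (c / 1200), f)  # exact round trip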
src/pitch_to_audio_utils.py
ADDED
@@ -0,0 +1,121 @@
import librosa as li
import torch
from tqdm import tqdm
import numpy as np
import gin
import logging


@gin.configurable
def torch_stft(x, nfft):
    # amplitude-normalized STFT: scaled by 2 / mean(window) so torch_istft can undo it
    window = torch.hann_window(nfft).to(x)
    x = torch.stft(
        x,
        n_fft=nfft,
        hop_length=nfft // 4,
        win_length=nfft,
        window=window,
        center=True,
        return_complex=True,
    )
    x = 2 * x / torch.mean(window)
    return x


@gin.configurable
def torch_istft(x, nfft):
    window = torch.hann_window(nfft).to(x.device)
    x = x / 2 * torch.mean(window)  # undo the scaling applied in torch_stft
    return torch.istft(
        x,
        n_fft=nfft,
        hop_length=nfft // 4,
        win_length=nfft,
        window=window,
        center=True,
    )
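

def _stft_roundtrip_check():
    # Hedged sanity check, not part of the original commit: with a Hann window
    # at hop = nfft // 4 (which satisfies the COLA condition) and the mirrored
    # 2 / mean(window) scalings, torch_istft should undo torch_stft essentially
    # exactly; istft without an explicit length returns hop * (frames - 1) samples.
    x = torch.randn(1, 16000)          # 1 s of noise at 16 kHz, batch of 1
    spec = torch_stft(x, nfft=1024)    # complex, shape (1, 513, frames)
    y = torch_istft(spec, nfft=1024)
    assert torch.allclose(x[..., :y.shape[-1]], y, atol=1e-4)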


@gin.configurable
def to_mels(stft, nfft, num_mels, sr, eps=1e-2):
    # project the magnitude STFT onto a mel filterbank and compress with log(. + eps)
    mels = li.filters.mel(
        sr=sr,
        n_fft=nfft,
        n_mels=num_mels,
        fmin=40,
    )
    mels = torch.from_numpy(mels).to(stft)
    mel_stft = torch.einsum("mf,bft->bmt", mels, stft)
    mel_stft = torch.log(mel_stft + eps)
    return mel_stft


@gin.configurable
def from_mels(mel_stft, nfft, num_mels, sr, eps=1e-2):
    # approximate inverse of to_mels via the pseudo-inverse of the filterbank
    mels = li.filters.mel(
        sr=sr,
        n_fft=nfft,
        n_mels=num_mels,
        fmin=40,
    )
    mels = torch.from_numpy(mels).to(mel_stft)
    mels = torch.pinverse(mels)
    mel_stft = torch.exp(mel_stft) - eps
    stft = torch.einsum("fm,bmt->bft", mels, mel_stft)
    return stft


@gin.configurable
def torch_gl(stft, nfft, sr, n_iter):
    # Griffin-Lim: iteratively re-estimate phase for a magnitude spectrogram

    def _gl_iter(phase, xs, stft):
        del xs
        c_stft = stft * torch.exp(1j * phase)
        rec = torch_istft(c_stft, nfft)
        r_stft = torch_stft(rec, nfft)
        phase = torch.angle(r_stft)
        return phase, None

    phase = torch.rand_like(stft) * 2 * torch.pi

    for _ in tqdm(range(n_iter)):
        phase, _ = _gl_iter(phase, None, stft)

    c_stft = stft * torch.exp(1j * phase)
    audio = torch_istft(c_stft, nfft)

    return audio


@gin.configurable
def normalize(x, qt=None):
    x_flat = x.reshape(-1, 1)
    if qt is None:
        logging.warning('No quantile transformer found, returning input')
        return x
    if isinstance(x_flat, torch.Tensor):  # sklearn expects numpy input
        x_flat = x_flat.detach().cpu().numpy()
    return torch.Tensor(qt.transform(x_flat).reshape(x.shape))


@gin.configurable
def unnormalize(x, qt=None):
    x_flat = x.reshape(-1, 1)
    if qt is None:
        logging.warning('No quantile transformer found, returning input')
        return x
    if isinstance(x_flat, torch.Tensor):
        x_flat = x_flat.detach().cpu().numpy()
    return torch.Tensor(qt.inverse_transform(x_flat).reshape(x.shape))


@gin.configurable
def audio_to_normalized_mels(x, nfft, num_mels, sr, qt):
    stfts = torch_stft(x, nfft=nfft).abs()[..., :-1]
    mel_stfts = to_mels(stfts, nfft, num_mels, sr)
    return normalize(mel_stfts, qt).to(x)


@gin.configurable
def normalized_mels_to_audio(x, nfft, num_mels, sr, qt, n_iter=20):
    x = unnormalize(x, qt).to(x)
    x = from_mels(x, nfft, num_mels, sr)
    x = torch.clamp(x, 0, nfft)  # magnitudes must be non-negative for Griffin-Lim
    x = torch_gl(x, nfft, sr, n_iter=n_iter)
    return x
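
A minimal round-trip sketch of the helpers above: audio to normalized log-mels and back through Griffin-Lim. The values nfft=1024, num_mels=128, sr=16000 are illustrative rather than the Space's actual gin bindings, and qt=None deliberately skips quantile normalization (both normalize functions warn and pass the input through):

import torch

sr, nfft, num_mels = 16000, 1024, 128   # illustrative values, not the real gin config
x = torch.randn(1, sr)                  # 1 s of noise at 16 kHz, batch of 1

mels = audio_to_normalized_mels(x, nfft, num_mels, sr, qt=None)   # log-mel, (1, 128, frames)
audio = normalized_mels_to_audio(mels, nfft, num_mels, sr, qt=None, n_iter=32)
print(mels.shape, audio.shape)          # reconstruction is close to the input length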