import os
import glob
import yaml
import torch
import __main__
import numpy as np
import soundfile as sf
import librosa
import librosa.display
import matplotlib.pyplot as plt

import spaces
import gradio as gr

from src.model.nn.synthesizer import Synthesizer
from src.utils.misc import triangular, downsample
from src.utils.plot import state_video as plot_state_video
from src.utils.audio import mel_basis, state_to_wav
from src.utils.control import vibrato as control_vibrato
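
# Gradio demo: loads the pretrained string-synthesizer checkpoint from ckpt/,
# synthesizes a plucked-string sound from the UI controls defined below, and
# returns a mel-spectrogram image plus a rendered string-state video.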


class ConfigArgument:
    """Attribute container that also supports dict-style indexing."""
    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        return setattr(self, key, value)

# Expose ConfigArgument on __main__ so that torch.load can unpickle checkpoints
# that reference __main__.ConfigArgument.
setattr(__main__, "ConfigArgument", ConfigArgument)


def filter_state_dict(ckpt):
    """Strip the leading 'model.' prefix from checkpoint keys."""
    out_dict = {}
    for key in ckpt.keys():
        new_key = key[6:] if str(key)[:6] == 'model.' else key
        out_dict[new_key] = ckpt[key]
    return out_dict
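
# For example (illustrative key names): 'model.encoder.weight' -> 'encoder.weight',
# while keys without the prefix are kept unchanged.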


def flush(directory):
    """Create `directory` if needed and delete any files already inside it."""
    os.makedirs(directory, exist_ok=True)
    files = glob.glob(f'{directory}/*')
    for f in files:
        os.remove(f)


def add_glissando(f_0, Nt, sr, glissando, max_t):
    # Build a 0-to-1 ramp with a randomly sized flat onset and tail, scale it by
    # `glissando`, and apply it as a relative shift of the fundamental frequency.
    front = int(0.2 * np.random.rand() * sr * max_t)
    rear = int((0.2 * np.random.rand() + 0.3) * sr * max_t)
    middle = max(0, len(f_0) - front - rear)
    ramp = glissando * torch.cat((torch.zeros(front), torch.linspace(0, 1, middle), torch.ones(rear)), dim=-1)
    return f_0 * (1 + ramp)
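
# Example (illustrative): with glissando=0.2 the pitch ramps from f_0 up to 1.2 * f_0
# over the middle of the note; negative values ramp the pitch downward instead.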


def plot_spectrogram(path, x, n_fft=2048, hop_length=512, n_mel=256, samplerate=48000, max_duration=1):
    # Zero-pad to a fixed duration so every rendered spectrogram shares the same time axis.
    x_wave = np.zeros(int(max_duration * samplerate))
    x_wave[:len(x)] += x
    x_spec = librosa.stft(
        x_wave, n_fft=n_fft, hop_length=hop_length, win_length=n_fft, pad_mode='reflect')
    mag = np.abs(x_spec)
    mel_fbank = mel_basis(samplerate, n_fft, n_mel)
    mel = np.einsum('ij,jk->ik', mel_fbank, mag)  # project the magnitude spectrogram onto mel bands

    plt.figure(figsize=(7, 7))
    librosa.display.specshow(mel)
    plt.xticks([])
    plt.yticks([])
    plt.clim([0, 30])
    plt.tight_layout()
    plt.savefig(path, transparent=True)
    plt.close('all')


with open("ckpt/config.yaml") as stream:
    configs = yaml.safe_load(stream)

with open("ckpt/pitch.yaml") as stream:
    pitch_dict = yaml.safe_load(stream)
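
# `configs` must provide at least 'sr', 'block_size', and 'gain' (accessed below)
# plus the keyword arguments for Synthesizer; `pitch_dict` maps note names
# (e.g. 'C3') to fundamental-frequency values.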


def get_data(duration, resolution, note, glissando, vibrato, stiffness, tension, pluck, amplitude):
    sr = configs['sr']
    Nt = int(duration * sr)   # number of audio samples
    Nx = int(resolution)      # number of spatial points along the string

    xgrid = torch.linspace(0, 1, Nx)
    tgrid = torch.arange(Nt) / sr
    pitch = pitch_dict[note]

    # Fixed decay (T60) specification passed to the model.
    T60 = torch.Tensor([[[1000., 25.], [100., 30.]]])

    xg, tg = torch.meshgrid(xgrid, tgrid, indexing='ij')
    ka = torch.Tensor([stiffness]).view(-1, 1)
    al = torch.Tensor([tension]).view(-1, 1)
    f_0 = torch.ones(Nt) * pitch
    nx = torch.Tensor([[[Nx]]]).float()
    p_x = torch.ones_like(nx) * pluck
    p_a = torch.ones_like(nx) * amplitude
    u_0 = triangular(Nx, nx, p_x, p_a)  # triangular (plucked) initial displacement

    # Pitch trajectory: glissando ramp, vibrato modulation, then block-rate downsampling.
    f_0 = add_glissando(f_0, Nt, sr, glissando, Nt / sr)
    f_0 = f_0 + control_vibrato(f_0.view(1, -1), 1 / sr, mf=[3., 5.], ma=vibrato)
    f_0 = downsample(f_0, factor=configs['block_size'])

    # Broadcast the per-string parameters over the Nx spatial points.
    xg = xg[:, 0].view(-1, 1)
    ka = ka.repeat(Nx, 1)
    al = al.repeat(Nx, 1)
    f_0 = f_0.repeat(Nx, 1)
    u_0 = u_0.repeat(Nx, 1, 1)

    params = [xg, tg, ka, al, T60, None, None]
    return params, f_0, u_0


@spaces.GPU
def run(duration, resolution, pitch, glissando, vibrato, stiffness, tension, pluck, amplitude):
    # Load the pretrained synthesizer weights (stripping the 'model.' prefix).
    checkpoint = torch.load('ckpt/dmsp.ckpt', map_location='cpu')
    checkpoint = filter_state_dict(checkpoint['state_dict'])
    model = Synthesizer(**configs)
    model.load_state_dict(checkpoint)
    if torch.cuda.is_available():
        model = model.cuda()

    params, f_0, u_0 = get_data(
        duration, resolution, pitch, glissando, vibrato, stiffness, tension, pluck, amplitude)

    if torch.cuda.is_available():
        params = [p.cuda() if p is not None else p for p in params]
        f_0 = f_0.cuda()
        u_0 = u_0.cuda()

    with torch.no_grad():
        ut, mode_input, mode_output = model(params, f_0, u_0)
    ut = ut.detach().cpu()
    ut_wave = configs['gain'] * ut.mean(0)  # average over the batch dimension, then apply output gain

    save_dir = 'results'
    prefix = 'dmsp'
    fname = 'output'
    flush(save_dir)
    audio_name = f'{save_dir}/{fname}.wav'
    video_name = f'{save_dir}/{prefix}-{fname}.mp4'
    spec_name = f'{save_dir}/spec.png'

    ut = ut.numpy().T
    ut_wave = ut_wave.numpy()
    maxy = 0.022  # fixed y-axis limit for the state video
    sf.write(audio_name, ut_wave, samplerate=configs['sr'])
    plot_spectrogram(spec_name, ut_wave, samplerate=configs['sr'])
    plot_state_video(save_dir, ut, configs['sr'], prefix=prefix, fname=fname, maxy=maxy)
    return spec_name, video_name
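
# run() returns (spectrogram image path, video path), matching the gr.Image and gr.Video outputs below.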


pitch_list = [
    "G2", "Ab2", "A2", "Bb2", "B2", "C3", "Db3", "D3", "Eb3", "E3", "F3", "Gb3",
    "G3", "Ab3", "A3", "Bb3", "B3", "C4", "Db4", "D4", "Eb4", "E4", "F4", "Gb4", "G4",
]

duration = gr.Slider(0.1, 1.0, value=1.0, label="Temporal Duration")
resolution = gr.Slider(128, 256, value=256, label="Spatial Resolution", info='Reduce to simulate faster; recommended to keep at 256.')
pitch = gr.Dropdown(pitch_list, value="C3", label="Pitch", info="Specify the fundamental frequency as a musical note.")
glissando = gr.Slider(-0.4, 0.4, value=0, label="Glissando", info='Positive values make the pitch ascend, negative values make it descend.')
vibrato = gr.Slider(0, 0.25, value=0, label="Vibrato", info='Larger values add more vibrato.')
stiffness = gr.Slider(0.011, 0.029, value=0.02, label="Stiffness", info='Stiffness can change the resulting pitch; use low values when tension is high.')
tension = gr.Slider(1.0, 25, value=4, label="Stiffness-Tension Ratio", info='Tension can introduce non-linear effects such as pitch glide; use low values when stiffness is high.')
pluck = gr.Slider(0.12, 0.5, value=0.2, label="Plucking Position", info='Peak position of the initial condition.')
amplitude = gr.Slider(0.001, 0.02, value=0.015, label="Plucking Amplitude", info='Peak amplitude of the initial condition.')


demo = gr.Interface(
    fn=run,
    inputs=[
        duration, resolution, pitch, glissando, vibrato,
        stiffness, tension, pluck, amplitude,
    ],
    outputs=[
        gr.Image(),
        gr.Video(format='mp4', include_audio=True),
    ],
)
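
# Optional (standard Gradio): pass share=True to launch() for a temporary public link,
# e.g. demo.launch(share=True).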
demo.launch()