first commit
Browse files- .gitattributes +1 -0
- .gitignore +17 -0
- README.md +1 -1
- app.py +166 -0
- ckpt/config.yaml +16 -0
- ckpt/dmsp.ckpt +3 -0
- ckpt/pitch.yaml +25 -0
- requirements.txt +5 -0
- src/model/nn/blocks.py +208 -0
- src/model/nn/ddsp.py +69 -0
- src/model/nn/dmsp.py +63 -0
- src/model/nn/synthesizer.py +125 -0
- src/utils/audio.py +219 -0
- src/utils/control.py +61 -0
- src/utils/ddsp.py +175 -0
- src/utils/misc.py +336 -0
- src/utils/plot.py +1132 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
ckpt/dmsp.ckpt filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.*.swp
|
2 |
+
.*.py.swp
|
3 |
+
*/.*.py.swp
|
4 |
+
*/*/.*.py.swp
|
5 |
+
*/*/*/.*.py.swp
|
6 |
+
__pycache__/
|
7 |
+
*/__pycache__/
|
8 |
+
*/*/__pycache__/
|
9 |
+
*/*/*/__pycache__/
|
10 |
+
|
11 |
+
src/*.py
|
12 |
+
src/configs
|
13 |
+
src/dataset
|
14 |
+
src/task
|
15 |
+
src/model/cpp
|
16 |
+
src/model/*.py
|
17 |
+
check.py
|
README.md
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
emoji: 😻
|
4 |
colorFrom: indigo
|
5 |
colorTo: red
|
|
|
1 |
---
|
2 |
+
title: dmsp
|
3 |
emoji: 😻
|
4 |
colorFrom: indigo
|
5 |
colorTo: red
|
app.py
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import glob
|
3 |
+
import yaml
|
4 |
+
import torch
|
5 |
+
import __main__
|
6 |
+
import numpy as np
|
7 |
+
import soundfile as sf
|
8 |
+
import librosa
|
9 |
+
import librosa.display
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
|
12 |
+
import gradio as gr
|
13 |
+
|
14 |
+
from src.model.nn.synthesizer import Synthesizer
|
15 |
+
from src.utils.misc import triangular, downsample
|
16 |
+
from src.utils.plot import state_video as plot_state_video
|
17 |
+
from src.utils.audio import mel_basis, state_to_wav
|
18 |
+
from src.utils.control import vibrato as control_vibrato
|
19 |
+
|
20 |
+
class ConfigArgument:
|
21 |
+
def __getitem__(self,key):
|
22 |
+
return getattr(self, key)
|
23 |
+
def __setitem__(self,key,value):
|
24 |
+
return setattr(self, key, value)
|
25 |
+
setattr(__main__, "ConfigArgument", ConfigArgument)
|
26 |
+
|
27 |
+
def filter_state_dict(ckpt):
|
28 |
+
out_dict = {}
|
29 |
+
for key in ckpt.keys():
|
30 |
+
new_key = key[6:] if str(key)[:6] == 'model.' else key
|
31 |
+
out_dict[new_key] = ckpt[key]
|
32 |
+
return out_dict
|
33 |
+
|
34 |
+
def flush(directory):
|
35 |
+
os.makedirs(directory, exist_ok=True)
|
36 |
+
files = glob.glob(f'{directory}/*')
|
37 |
+
for f in files:
|
38 |
+
os.remove(f)
|
39 |
+
|
40 |
+
def add_glissando(f_0, Nt, sr, glissando, max_t):
|
41 |
+
front = int(0.2 * np.random.rand() * sr * max_t)
|
42 |
+
rear = int((0.2 * np.random.rand() + 0.3) * sr * max_t)
|
43 |
+
middle = max(0, len(f_0) - front - rear)
|
44 |
+
ramp = glissando * torch.cat((torch.zeros(front), torch.linspace(0,1,middle), torch.ones(rear)), dim=-1)
|
45 |
+
return f_0 * (1 + ramp)
|
46 |
+
|
47 |
+
def plot_spectrogram(path, x, n_fft=2048, hop_length=512, n_mel=256, samplerate=48000, max_duration=1):
|
48 |
+
x_wave = np.zeros(int(max_duration * samplerate))
|
49 |
+
x_wave[:len(x)] += x
|
50 |
+
x_spec = librosa.stft(
|
51 |
+
x_wave, n_fft=n_fft, hop_length=hop_length, win_length=n_fft, pad_mode='reflect')
|
52 |
+
mag = np.abs(x_spec) # (n_frames, n_freq)
|
53 |
+
mel_fbank = mel_basis(samplerate, n_fft, n_mel) # (n_mel, n_freq)
|
54 |
+
mel = np.einsum('ij,jk->ik', mel_fbank, mag) # (n_frames, n_mel)
|
55 |
+
|
56 |
+
plt.figure(figsize=(7,7))
|
57 |
+
librosa.display.specshow(mel)
|
58 |
+
plt.xticks([])
|
59 |
+
plt.yticks([])
|
60 |
+
plt.clim([0, 30])
|
61 |
+
plt.tight_layout()
|
62 |
+
plt.savefig(path, transparent=True)
|
63 |
+
plt.close('all')
|
64 |
+
plt.clf()
|
65 |
+
|
66 |
+
with open("ckpt/config.yaml") as stream:
|
67 |
+
configs = yaml.safe_load(stream)
|
68 |
+
|
69 |
+
with open("ckpt/pitch.yaml") as stream:
|
70 |
+
pitch_dict = yaml.safe_load(stream)
|
71 |
+
|
72 |
+
def get_data(duration, resolution, note, glissando, vibrato, stiffness, tension, pluck, amplitude):
|
73 |
+
sr = configs['sr']
|
74 |
+
Nt = int(duration * sr)
|
75 |
+
Nx = int(resolution)
|
76 |
+
|
77 |
+
xgrid = torch.linspace(0,1,Nx)
|
78 |
+
tgrid = torch.arange(Nt) / sr
|
79 |
+
pitch = pitch_dict[note]
|
80 |
+
|
81 |
+
t60_min_1=20.; t60_max_1=30.; t60_min_2=30.; t60_max_2=30.
|
82 |
+
t60_diff_max=5.
|
83 |
+
T60 = torch.Tensor([[[1000., 25.],[100., 30.]]])
|
84 |
+
|
85 |
+
Nw = int(Nt / configs['block_size']) + 1
|
86 |
+
|
87 |
+
xg, tg = torch.meshgrid(xgrid, tgrid, indexing='ij')
|
88 |
+
ka = torch.Tensor([stiffness]).view(-1,1) # (1,1)
|
89 |
+
al = torch.Tensor([tension]).view(-1,1) # (1,1)
|
90 |
+
f_0 = torch.ones(Nt) * pitch # (Nt)
|
91 |
+
nx = torch.Tensor([[[Nx]]]).float()
|
92 |
+
p_x = torch.ones_like(nx) * pluck
|
93 |
+
p_a = torch.ones_like(nx) * amplitude
|
94 |
+
u_0 = triangular(Nx, nx, p_x, p_a) # (1, 1, Nx)
|
95 |
+
|
96 |
+
f_0 = add_glissando(f_0, Nt, sr, glissando, Nt / sr)
|
97 |
+
f_0 = f_0 + control_vibrato(f_0.view(1,-1), 1/sr, mf=[3.,5.], ma=vibrato)
|
98 |
+
f_0 = downsample(f_0, factor=configs['block_size'])
|
99 |
+
|
100 |
+
xg = xg[:,0].view(-1,1) # (Nx, 1)
|
101 |
+
tg = tg # (Nx, Nt)
|
102 |
+
ka = ka.repeat(Nx,1) # (Nx, 1)
|
103 |
+
al = al.repeat(Nx,1) # (Nx, 1)
|
104 |
+
T60 = T60 # (Nx, 1, 1)
|
105 |
+
f_0 = f_0.repeat(Nx,1) # (Nx, Nw)
|
106 |
+
u_0 = u_0.repeat(Nx,1,1) # (Nx, 1, Nx)
|
107 |
+
|
108 |
+
params = [xg, tg, ka, al, T60, None, None]
|
109 |
+
return params, f_0, u_0
|
110 |
+
|
111 |
+
def run(duration, resolution, pitch, glissando, vibrato, stiffness, tension, pluck, amplitude):
|
112 |
+
checkpoint = torch.load('ckpt/dmsp.ckpt', map_location='cpu')
|
113 |
+
checkpoint = filter_state_dict(checkpoint['state_dict'])
|
114 |
+
model = Synthesizer(**configs)
|
115 |
+
model.load_state_dict(checkpoint)
|
116 |
+
|
117 |
+
params, f_0, u_0 = get_data( \
|
118 |
+
duration, resolution, pitch, glissando, vibrato, stiffness, tension, pluck, amplitude)
|
119 |
+
|
120 |
+
with torch.no_grad():
|
121 |
+
ut, mode_input, mode_output = model(params, f_0, u_0)
|
122 |
+
ut = ut.detach() # (Nx, Nt)
|
123 |
+
ut_wave = configs['gain'] * ut.mean(0)
|
124 |
+
|
125 |
+
save_dir = 'results'
|
126 |
+
prefix = 'dmsp'
|
127 |
+
fname = 'output'
|
128 |
+
flush(save_dir)
|
129 |
+
audio_name = f'{save_dir}/{fname}.wav'
|
130 |
+
video_name = f'{save_dir}/{prefix}-{fname}.mp4'
|
131 |
+
spec_name = f'{save_dir}/spec.png'
|
132 |
+
|
133 |
+
ut = ut.numpy().T
|
134 |
+
ut_wave = ut_wave.numpy()
|
135 |
+
maxy = 0.022
|
136 |
+
sf.write(audio_name, ut_wave, samplerate=configs['sr'])
|
137 |
+
plot_spectrogram(spec_name, ut_wave, samplerate=configs['sr'])
|
138 |
+
plot_state_video(save_dir, ut, configs['sr'], prefix=prefix, fname=fname, maxy=maxy)
|
139 |
+
return spec_name, video_name
|
140 |
+
|
141 |
+
pitch_list = ["G2", "Ab2", "A2", "Bb2", "B2", "C3", "Db3", "D3", "Eb3", "E3", "F3", "Gb3", "G3", "Ab3", "A3", "Bb3", "B3", "C4", "Db4", "D4", "Eb4", "E4", "F4", "Gb4", "G4",]
|
142 |
+
|
143 |
+
duration = gr.Slider(0.1, 1.0, value=1.0, label="Time Duration")
|
144 |
+
resolution = gr.Slider(128, 256, value=256, label="Space Resolution", info='Reduce to simulate faster. Recommended to leave it as 256.')
|
145 |
+
pitch = gr.Dropdown(pitch_list, value="C3", label="Pitch", info="Specify the fundamental frequency as a musical note.")
|
146 |
+
glissando = gr.Slider(-0.4, 0.4, value=0, label="Glissando", info='Set +/- to ascend (+) or descend (-) the pitch')
|
147 |
+
vibrato = gr.Slider(0, 0.25, value=0, label="Vibrato", info='Set larger value to add more vibrato')
|
148 |
+
stiffness = gr.Slider(0.011, 0.029, value=0.02, label="Stiffness", info='Stiffness can change the resulting pitch. Specify low values when tension is high')
|
149 |
+
tension = gr.Slider(1.0, 25, value=4, label="Tension", info='Tension can introduce non-linear effects such as pitch glide. Specify low values when stiffness is high')
|
150 |
+
pluck = gr.Slider(0.12, 0.5, value=0.2, label="Pluck Position", info='Peak position of an initial condition')
|
151 |
+
amplitude = gr.Slider(0.001, 0.02, value=0.015, label="Pluck Amplitude", info='Peak amplitude of an initial condition')
|
152 |
+
|
153 |
+
demo = gr.Interface(
|
154 |
+
fn=run,
|
155 |
+
inputs=[
|
156 |
+
duration, resolution, pitch, glissando, vibrato,
|
157 |
+
stiffness, tension, pluck, amplitude,
|
158 |
+
],
|
159 |
+
outputs=[
|
160 |
+
gr.Image(),
|
161 |
+
gr.Video(format='mp4', include_audio=True),
|
162 |
+
],
|
163 |
+
)
|
164 |
+
demo.launch()
|
165 |
+
|
166 |
+
|
ckpt/config.yaml
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
n_modes: 40
|
2 |
+
n_bands: 65
|
3 |
+
embed_dim: 128
|
4 |
+
use_precomputed_mod: False
|
5 |
+
harmonic: 'inharmonic'
|
6 |
+
hidden_dim: 512
|
7 |
+
block_size: 256
|
8 |
+
sr: 48000
|
9 |
+
gain: 100
|
10 |
+
x_scale: [0., 1.]
|
11 |
+
t_scale: [0., .3]
|
12 |
+
gamma_scale: [196, 880]
|
13 |
+
kappa_scale: [.01, .03]
|
14 |
+
alpha_scale: [1., 30.]
|
15 |
+
sig_0_scale: [0., 0.7]
|
16 |
+
sig_1_scale: [0., 0.00001]
|
ckpt/dmsp.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:071f2f176f1ced82e287ee5d41ffbe36dc6f613e7e5b1dea4e19e9609f82655b
|
3 |
+
size 104471835
|
ckpt/pitch.yaml
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
G2 : 98.00
|
2 |
+
Ab2: 103.83
|
3 |
+
A2 : 110.00
|
4 |
+
Bb2: 116.54
|
5 |
+
B2 : 123.47
|
6 |
+
C3 : 130.81
|
7 |
+
Db3: 138.59
|
8 |
+
D3 : 146.83
|
9 |
+
Eb3: 155.56
|
10 |
+
E3 : 164.81
|
11 |
+
F3 : 174.61
|
12 |
+
Gb3: 185.00
|
13 |
+
G3 : 196.00
|
14 |
+
Ab3: 207.65
|
15 |
+
A3 : 220.00
|
16 |
+
Bb3: 233.08
|
17 |
+
B3 : 246.94
|
18 |
+
C4 : 261.63
|
19 |
+
Db4: 277.18
|
20 |
+
D4 : 293.66
|
21 |
+
Eb4: 311.13
|
22 |
+
E4 : 329.63
|
23 |
+
F4 : 349.23
|
24 |
+
Gb4: 369.99
|
25 |
+
G4 : 392.00
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==1.12.1
|
2 |
+
numpy==1.24.3
|
3 |
+
einops
|
4 |
+
librosa
|
5 |
+
omegaconf
|
src/model/nn/blocks.py
ADDED
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import numpy as np
|
6 |
+
from einops import rearrange
|
7 |
+
|
8 |
+
from src.utils import misc as utils
|
9 |
+
|
10 |
+
def apply_gain(x, gain, fn=None):
|
11 |
+
gain = fn(gain) if fn is not None else gain
|
12 |
+
x_list = x.chunk(len(gain), -1)
|
13 |
+
x_list = [gain[i] * x_i for i, x_i in enumerate(x_list)]
|
14 |
+
return torch.cat(x_list, dim=-1)
|
15 |
+
|
16 |
+
class FMBlock(nn.Module):
|
17 |
+
def __init__(self, input_dim, embed_dim, num_features):
|
18 |
+
super().__init__()
|
19 |
+
concat_size = embed_dim * num_features + embed_dim
|
20 |
+
feature_dim = embed_dim * num_features
|
21 |
+
self.rff2 = RFF2(input_dim, embed_dim//2)
|
22 |
+
self.tmlp = mlp(concat_size, feature_dim, 5)
|
23 |
+
self.proj = nn.Linear(concat_size, 2*input_dim)
|
24 |
+
self.activation = nn.GLU(dim=-1)
|
25 |
+
|
26 |
+
gain_in = torch.randn(num_features) / 2
|
27 |
+
gain_out = torch.Tensor([0.1])
|
28 |
+
self.register_parameter('gain_in', nn.Parameter(gain_in, requires_grad=True))
|
29 |
+
self.register_parameter('gain_out', nn.Parameter(gain_out, requires_grad=True))
|
30 |
+
|
31 |
+
def forward(self, input, feature, slider, omega):
|
32 |
+
''' input : (B T input_dim)
|
33 |
+
feature: (B T feature_dim)
|
34 |
+
slider : (B T 1)
|
35 |
+
'''
|
36 |
+
_input = input / (1.3*math.pi) - 1
|
37 |
+
_input = self.rff2(_input)
|
38 |
+
feature = apply_gain(feature, self.gain_in, torch.tanh)
|
39 |
+
|
40 |
+
x = torch.cat((_input, feature), dim=-1)
|
41 |
+
x = torch.cat((self.tmlp(x), _input), dim=-1)
|
42 |
+
x = self.activation(self.proj(x))
|
43 |
+
|
44 |
+
gate = torch.tanh((slider - 1) * self.gain_out)
|
45 |
+
return input + omega * x * gate
|
46 |
+
|
47 |
+
class AMBlock(nn.Module):
|
48 |
+
def __init__(self, input_dim, embed_dim, num_features):
|
49 |
+
super().__init__()
|
50 |
+
concat_size = embed_dim * num_features + embed_dim
|
51 |
+
feature_dim = embed_dim * num_features
|
52 |
+
self.rff2 = RFF2(input_dim, embed_dim//2)
|
53 |
+
self.tmlp = mlp(concat_size, feature_dim, 5)
|
54 |
+
self.proj = nn.Linear(concat_size, 2*input_dim)
|
55 |
+
self.activation = nn.GLU(dim=-1)
|
56 |
+
|
57 |
+
gain_in = torch.randn(num_features) / 2
|
58 |
+
self.register_parameter('gain_in', nn.Parameter(gain_in, requires_grad=True))
|
59 |
+
|
60 |
+
def forward(self, input, feature, slider):
|
61 |
+
''' input : (B T input_dim)
|
62 |
+
feature: (B T feature_dim)
|
63 |
+
slider : (B T 1)
|
64 |
+
'''
|
65 |
+
_input = input * 110 - 0.55
|
66 |
+
_input = self.rff2(_input)
|
67 |
+
feature = apply_gain(feature, self.gain_in, torch.tanh)
|
68 |
+
|
69 |
+
x = torch.cat((_input, feature), dim=-1)
|
70 |
+
x = torch.cat((self.tmlp(x), _input), dim=-1)
|
71 |
+
x = self.activation(self.proj(x))
|
72 |
+
|
73 |
+
return input * (1 + x)
|
74 |
+
|
75 |
+
class ModBlock(nn.Module):
|
76 |
+
def __init__(self, input_dim, feature_dim, embed_dim):
|
77 |
+
super().__init__()
|
78 |
+
cat_size = 1+feature_dim
|
79 |
+
self.tmlp = mlp(cat_size, feature_dim, 2)
|
80 |
+
self.proj = nn.Linear(cat_size, 2)
|
81 |
+
self.activation = nn.GLU(dim=-1)
|
82 |
+
|
83 |
+
def forward(self, input, feature, slider):
|
84 |
+
''' input : (B T input_dim)
|
85 |
+
feature: (B T feature_dim)
|
86 |
+
slider : (B T 1)
|
87 |
+
'''
|
88 |
+
input = input.unsqueeze(-1) # (B T input_dim 1)
|
89 |
+
feature = feature.unsqueeze(-2).repeat(1,1,input.size(-2),1)
|
90 |
+
x = torch.cat((input, feature), dim=-1)
|
91 |
+
x = torch.cat((self.tmlp(x), input), dim=-1)
|
92 |
+
x = self.activation(self.proj(x))
|
93 |
+
return (input * (1 + x)).squeeze(-1)
|
94 |
+
|
95 |
+
def mlp(in_size, hidden_size, n_layers):
|
96 |
+
channels = [in_size] + (n_layers) * [hidden_size]
|
97 |
+
net = []
|
98 |
+
for i in range(n_layers):
|
99 |
+
net.append(nn.Linear(channels[i], channels[i + 1]))
|
100 |
+
#net.append(nn.LayerNorm(channels[i + 1]))
|
101 |
+
net.append(nn.PReLU())
|
102 |
+
return nn.Sequential(*net)
|
103 |
+
|
104 |
+
|
105 |
+
class RFF2(nn.Module):
|
106 |
+
""" Random Fourier Features Module """
|
107 |
+
def __init__(self, input_dim, embed_dim, scale=1.):
|
108 |
+
super().__init__()
|
109 |
+
#N = torch.randn(input_dim, embed_dim)
|
110 |
+
N = torch.ones((input_dim, embed_dim)) / input_dim / embed_dim
|
111 |
+
N = nn.Parameter(N, requires_grad=False)
|
112 |
+
e = torch.Tensor([scale])
|
113 |
+
e = nn.Parameter(e, requires_grad=True)
|
114 |
+
self.register_buffer('N', N)
|
115 |
+
self.register_parameter('e', e)
|
116 |
+
|
117 |
+
def forward(self, x):
|
118 |
+
''' x: (Bs, Nt, input_dim)
|
119 |
+
-> (Bs, Nt, embed_dim)
|
120 |
+
'''
|
121 |
+
B = self.e * self.N
|
122 |
+
x_embd = utils.fourier_feature(x, B)
|
123 |
+
return x_embd
|
124 |
+
|
125 |
+
|
126 |
+
class RFF(nn.Module):
|
127 |
+
""" Random Fourier Features Module """
|
128 |
+
def __init__(self, scales, embed_dim):
|
129 |
+
super().__init__()
|
130 |
+
input_dim = len(scales)
|
131 |
+
N = torch.randn(input_dim, embed_dim)
|
132 |
+
N = nn.Parameter(N, requires_grad=False)
|
133 |
+
e = torch.Tensor(scales).view(-1,1)
|
134 |
+
e = nn.Parameter(e, requires_grad=True)
|
135 |
+
self.register_buffer('N', N)
|
136 |
+
self.register_parameter('e', e)
|
137 |
+
|
138 |
+
def forward(self, x):
|
139 |
+
''' x: (Bs, Nt, input_dim)
|
140 |
+
-> (Bs, Nt, input_dim*embed_dim)
|
141 |
+
'''
|
142 |
+
xs = x.chunk(self.N.size(0), -1) # (Bs, Nt, 1) * input_dim
|
143 |
+
Ns = self.N.chunk(self.N.size(0), 0) # (1, embed_dim) * input_dim
|
144 |
+
Bs = [torch.pow(10, self.e[i]) * N for i, N in enumerate(Ns)]
|
145 |
+
x_embd = [utils.fourier_feature(xs[i], B) for i, B in enumerate(Bs)]
|
146 |
+
return torch.cat(x_embd, dim=-1)
|
147 |
+
|
148 |
+
|
149 |
+
class ModeEstimator(nn.Module):
|
150 |
+
def __init__(self, n_modes, hidden_dim, kappa_scale=None, gamma_scale=None, inharmonic=True, sr=48000):
|
151 |
+
super().__init__()
|
152 |
+
self.sr = sr
|
153 |
+
self.kappa_scale = kappa_scale
|
154 |
+
self.gamma_scale = gamma_scale
|
155 |
+
self.rff = RFF([1.]*5, hidden_dim//2)
|
156 |
+
self.a_mlp = mlp(5*hidden_dim, hidden_dim, 2)
|
157 |
+
self.a_proj = nn.Linear(hidden_dim, n_modes)
|
158 |
+
self.tanh = nn.Tanh()
|
159 |
+
if inharmonic:
|
160 |
+
self.f_mlp = mlp(5*hidden_dim, hidden_dim, 2)
|
161 |
+
self.f_proj = nn.Linear(hidden_dim, n_modes)
|
162 |
+
self.sigmoid = nn.Sigmoid()
|
163 |
+
else:
|
164 |
+
self.f_mlp = None
|
165 |
+
self.f_proj = None
|
166 |
+
self.sigmoid = nn.Sigmoid()
|
167 |
+
|
168 |
+
def forward(self, u_0, x_p, kappa, gamma):
|
169 |
+
''' u_0 : (b, 1, x)
|
170 |
+
x_p : (b, 1, 1)
|
171 |
+
kappa : (b, 1, 1)
|
172 |
+
gamma : (b, 1, 1)
|
173 |
+
'''
|
174 |
+
p_x = torch.argmax(u_0, dim=-1, keepdim=True) / 255. # (b, 1, 1)
|
175 |
+
p_a = torch.max(u_0, dim=-1, keepdim=True).values / 0.02 # (b, 1, 1)
|
176 |
+
kappa = self.normalize_kappa(kappa)
|
177 |
+
gamma = self.normalize_gamma(gamma)
|
178 |
+
con = torch.cat((p_x, p_a, x_p, kappa, gamma), dim=-1) # (b, 1, 5)
|
179 |
+
con = self.rff(con) # (b, 1, 3*hidden_dim)
|
180 |
+
|
181 |
+
mode_amps = self.a_mlp(con) # (b, 1, k)
|
182 |
+
mode_amps = self.tanh(1e-3 * self.a_proj(mode_amps)) # (b, 1, m)
|
183 |
+
|
184 |
+
if self.f_mlp is not None:
|
185 |
+
mode_freq = self.f_mlp(con) # (b, 1, k)
|
186 |
+
mode_freq = 0.3 * self.sigmoid(self.f_proj(mode_freq)) # (b, 1, m)
|
187 |
+
mode_freq = mode_freq.cumsum(-1)
|
188 |
+
else:
|
189 |
+
int_mults = torch.ones_like(mode_amps).cumsum(-1) # (b, 1, k)
|
190 |
+
omega = gamma / self.sr * (2*math.pi)
|
191 |
+
mode_freq = omega * int_mults
|
192 |
+
|
193 |
+
return mode_amps, mode_freq
|
194 |
+
|
195 |
+
def normalize_gamma(self, x):
|
196 |
+
if self.gamma_scale is not None:
|
197 |
+
minval = min(self.gamma_scale)
|
198 |
+
denval = max(self.gamma_scale) - minval
|
199 |
+
x = (x - minval) / denval
|
200 |
+
return x
|
201 |
+
|
202 |
+
def normalize_kappa(self, x):
|
203 |
+
if self.kappa_scale is not None:
|
204 |
+
minval = min(self.kappa_scale)
|
205 |
+
denval = max(self.kappa_scale) - minval
|
206 |
+
x = (x - minval) / denval
|
207 |
+
return x
|
208 |
+
|
src/model/nn/ddsp.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from src.model.nn.blocks import FMBlock, AMBlock
|
4 |
+
from src.utils.ddsp import upsample
|
5 |
+
from src.utils.ddsp import remove_above_nyquist_mode
|
6 |
+
from src.utils.ddsp import amp_to_impulse_response, fft_convolve
|
7 |
+
from src.utils.ddsp import modal_synth
|
8 |
+
from src.utils.ddsp import resample
|
9 |
+
import math
|
10 |
+
|
11 |
+
class DDSP(nn.Module):
|
12 |
+
def __init__(self,
|
13 |
+
feature_size, hidden_size,
|
14 |
+
n_modes, n_bands, sampling_rate, block_size,
|
15 |
+
fm=False,
|
16 |
+
):
|
17 |
+
super().__init__()
|
18 |
+
self.n_modes = n_modes
|
19 |
+
|
20 |
+
self.freq_modulator = FMBlock(n_modes, feature_size) if fm else None
|
21 |
+
self.coef_modulator = AMBlock(n_modes, feature_size)
|
22 |
+
self.noise_proj = nn.Linear(feature_size, n_bands)
|
23 |
+
|
24 |
+
noise_gate = nn.Parameter(torch.tensor([1e-2]), requires_grad=True)
|
25 |
+
self.register_parameter("noise_gate", noise_gate)
|
26 |
+
self.register_buffer("sampling_rate", torch.tensor(sampling_rate))
|
27 |
+
self.register_buffer("block_size", torch.tensor(block_size))
|
28 |
+
|
29 |
+
def forward(self, hidden, mode_freq, mode_coef, times, alpha, lengths):
|
30 |
+
''' hidden : (Bs, 1, hidden_size)
|
31 |
+
mode_freq : (Bs, Nt, n_modes)
|
32 |
+
mode_coef : (Bs, 1, n_modes)
|
33 |
+
times : (Bs, Nt, 1)
|
34 |
+
'''
|
35 |
+
if self.freq_modulator is None:
|
36 |
+
freq_m = mode_freq # integer multiples
|
37 |
+
else:
|
38 |
+
freq_m = self.freq_modulator(mode_freq, hidden)
|
39 |
+
coef_m = self.coef_modulator(mode_coef, hidden, times)
|
40 |
+
|
41 |
+
#==============================
|
42 |
+
# harmonic part
|
43 |
+
#==============================
|
44 |
+
freqs = freq_m / (2*math.pi) * self.sampling_rate
|
45 |
+
coef_m = remove_above_nyquist_mode(coef_m, freqs, self.sampling_rate) # (Bs, Nt, n_modes)
|
46 |
+
freq_s = upsample(freq_m, self.block_size).narrow(1,0,lengths)
|
47 |
+
coef_s = upsample(coef_m, self.block_size).narrow(1,0,lengths)
|
48 |
+
harmonic = modal_synth(freq_s, coef_s, self.sampling_rate)
|
49 |
+
|
50 |
+
#==============================
|
51 |
+
# noise part
|
52 |
+
#==============================
|
53 |
+
ngate = torch.tanh((alpha - 1) * self.noise_gate)
|
54 |
+
param = ngate * torch.sigmoid(self.noise_proj(hidden) - 5)
|
55 |
+
|
56 |
+
impulse = amp_to_impulse_response(param, self.block_size)
|
57 |
+
noise = torch.rand(
|
58 |
+
impulse.shape[0],
|
59 |
+
impulse.shape[1],
|
60 |
+
self.block_size,
|
61 |
+
).to(impulse) * 2 - 1
|
62 |
+
noise = fft_convolve(noise, impulse).contiguous()
|
63 |
+
noise = noise.reshape(noise.shape[0], -1, 1).narrow(1,0,lengths)
|
64 |
+
|
65 |
+
signal = harmonic + noise
|
66 |
+
return signal.squeeze(-1), freq_m, coef_m
|
67 |
+
|
68 |
+
|
69 |
+
|
src/model/nn/dmsp.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from src.model.nn.blocks import FMBlock, AMBlock, ModBlock
|
4 |
+
from src.utils.ddsp import scale_function, remove_above_nyquist, upsample
|
5 |
+
from src.utils.ddsp import remove_above_nyquist_mode
|
6 |
+
from src.utils.ddsp import harmonic_synth, amp_to_impulse_response, fft_convolve
|
7 |
+
from src.utils.ddsp import modal_synth
|
8 |
+
from src.utils.ddsp import resample
|
9 |
+
import math
|
10 |
+
|
11 |
+
class DMSP(nn.Module):
|
12 |
+
def __init__(self,
|
13 |
+
embed_dim, hidden_size, n_features,
|
14 |
+
n_modes, n_bands, sampling_rate, block_size,
|
15 |
+
):
|
16 |
+
super().__init__()
|
17 |
+
self.n_modes = n_modes
|
18 |
+
|
19 |
+
self.freq_modulator = FMBlock(n_modes, embed_dim, n_features)
|
20 |
+
self.coef_modulator = AMBlock(n_modes, embed_dim, n_features)
|
21 |
+
self.proj_noise = nn.Linear(n_features*embed_dim, n_bands)
|
22 |
+
|
23 |
+
self.register_buffer("sampling_rate", torch.tensor(sampling_rate))
|
24 |
+
self.register_buffer("block_size", torch.tensor(block_size))
|
25 |
+
|
26 |
+
def forward(self, hidden, mode_freq, mode_coef, times, alpha, omega, lengths):
|
27 |
+
''' hidden : (Bs, 1, hidden_size)
|
28 |
+
mode_freq : (Bs, Nt, n_modes)
|
29 |
+
mode_coef : (Bs, 1, n_modes)
|
30 |
+
times : (Bs, Nt, 1)
|
31 |
+
'''
|
32 |
+
freq_m = self.freq_modulator(mode_freq, hidden, alpha, omega)
|
33 |
+
coef_m = self.coef_modulator(mode_coef, hidden, times)
|
34 |
+
|
35 |
+
#==============================
|
36 |
+
# harmonic part
|
37 |
+
#==============================
|
38 |
+
freqs = freq_m / (2*math.pi) * self.sampling_rate
|
39 |
+
coef_m = remove_above_nyquist_mode(coef_m, freqs, self.sampling_rate) # (Bs, Nt, n_modes)
|
40 |
+
freq_s = upsample(freq_m, self.block_size).narrow(1,0,lengths)
|
41 |
+
coef_s = upsample(coef_m, self.block_size).narrow(1,0,lengths)
|
42 |
+
harmonic = modal_synth(freq_s, coef_s, self.sampling_rate)
|
43 |
+
|
44 |
+
#==============================
|
45 |
+
# noise part
|
46 |
+
#==============================
|
47 |
+
param = scale_function(self.proj_noise(hidden) - 5)
|
48 |
+
|
49 |
+
impulse = amp_to_impulse_response(param, self.block_size)
|
50 |
+
noise = torch.rand(
|
51 |
+
impulse.shape[0],
|
52 |
+
impulse.shape[1],
|
53 |
+
self.block_size,
|
54 |
+
).to(impulse) * 2 - 1
|
55 |
+
noise = fft_convolve(noise, impulse).contiguous()
|
56 |
+
noise = noise.reshape(noise.shape[0], -1, 1).narrow(1,0,lengths)
|
57 |
+
|
58 |
+
signal = harmonic + noise
|
59 |
+
return signal.squeeze(-1), freq_m, coef_m
|
60 |
+
|
61 |
+
|
62 |
+
|
63 |
+
|
src/model/nn/synthesizer.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from einops import rearrange
|
6 |
+
|
7 |
+
from src.utils import audio as audio
|
8 |
+
|
9 |
+
class Synthesizer(nn.Module):
|
10 |
+
""" Synthesizer Network """
|
11 |
+
def __init__(self,
|
12 |
+
embed_dim=64,
|
13 |
+
x_scale=1, t_scale=1,
|
14 |
+
gamma_scale=0, kappa_scale=0, alpha_scale=0, sig_0_scale=0, sig_1_scale=0,
|
15 |
+
**kwargs):
|
16 |
+
super().__init__()
|
17 |
+
self.sr=kwargs['sr']
|
18 |
+
hidden_dim=kwargs['hidden_dim']
|
19 |
+
self.n_modes = kwargs['n_modes']
|
20 |
+
inharmonic = kwargs['harmonic'].lower() == 'inharmonic'
|
21 |
+
|
22 |
+
self.x_scale = x_scale
|
23 |
+
self.t_scale = t_scale
|
24 |
+
self.gamma_scale = gamma_scale
|
25 |
+
self.kappa_scale = kappa_scale
|
26 |
+
self.alpha_scale = alpha_scale
|
27 |
+
self.sig_0_scale = sig_0_scale
|
28 |
+
self.sig_1_scale = sig_1_scale
|
29 |
+
|
30 |
+
from src.model.nn.blocks import RFF, ModeEstimator
|
31 |
+
n_feats = 7
|
32 |
+
self.material_encoder = RFF([1.]*n_feats, embed_dim // 2)
|
33 |
+
feature_size = embed_dim * n_feats
|
34 |
+
self.mode_estimator = ModeEstimator(
|
35 |
+
self.n_modes, embed_dim, kappa_scale, gamma_scale,
|
36 |
+
inharmonic=inharmonic,
|
37 |
+
)
|
38 |
+
if inharmonic:
|
39 |
+
from src.model.nn.dmsp import DMSP
|
40 |
+
self.net = DMSP(
|
41 |
+
embed_dim=embed_dim,
|
42 |
+
hidden_size=hidden_dim,
|
43 |
+
n_features=n_feats,
|
44 |
+
n_modes=kwargs['n_modes'],
|
45 |
+
n_bands=kwargs['n_bands'],
|
46 |
+
block_size=kwargs['block_size'],
|
47 |
+
sampling_rate=kwargs['sr'],
|
48 |
+
)
|
49 |
+
else:
|
50 |
+
from src.model.nn.ddsp import DDSP
|
51 |
+
self.net = DDSP(
|
52 |
+
feature_size=feature_size,
|
53 |
+
hidden_size=hidden_dim,
|
54 |
+
n_modes=kwargs['n_modes'],
|
55 |
+
n_bands=kwargs['n_bands'],
|
56 |
+
block_size=kwargs['block_size'],
|
57 |
+
sampling_rate=kwargs['sr'],
|
58 |
+
fm=kwargs['ddsp_frequency_modulation'],
|
59 |
+
)
|
60 |
+
|
61 |
+
def forward(self, params, pitch, initial):
|
62 |
+
''' params : input parameters
|
63 |
+
pitch : fundamental frequency in Hz
|
64 |
+
initial: initial condition
|
65 |
+
'''
|
66 |
+
space, times, kappa, alpha, t60, mode_freq, mode_coef = params
|
67 |
+
|
68 |
+
f_0 = pitch.unsqueeze(2) # (b, frames, 1)
|
69 |
+
times = times.unsqueeze(-1) # (b, sample, 1)
|
70 |
+
kappa = kappa.unsqueeze(-1) # (b, 1, 1)
|
71 |
+
alpha = alpha.unsqueeze(-1) # (b, 1, 1)
|
72 |
+
space = space.unsqueeze(-1) # (b, 1, 1)
|
73 |
+
gamma = 2*f_0 # (b, frames, 1)
|
74 |
+
omega = f_0 / self.sr * (2*math.pi) # (b, t, 1)
|
75 |
+
relf0 = omega - omega.narrow(1,0,1) # (b, t, 1)
|
76 |
+
|
77 |
+
in_coef, in_freq = self.mode_estimator(initial, space, kappa, gamma.narrow(1,9,1))
|
78 |
+
mode_coef = in_coef if mode_coef is None else mode_coef
|
79 |
+
mode_freq = in_freq if mode_freq is None else mode_freq
|
80 |
+
mode_freq = mode_freq + relf0 # linear FM
|
81 |
+
|
82 |
+
Nt = times.size(1) # total number of samples
|
83 |
+
Nf = mode_freq.size(1) # total number of frames
|
84 |
+
frames = self.get_frame_time(times, Nf)
|
85 |
+
|
86 |
+
space = space.repeat(1,f_0.size(1),1) # (b, frames, 1)
|
87 |
+
alpha = alpha.repeat(1,f_0.size(1),1) # (b, frames, 1)
|
88 |
+
kappa = kappa.repeat(1,f_0.size(1),1) # (b, frames, 1)
|
89 |
+
sigma = audio.T60_to_sigma(t60, f_0, 2*f_0*kappa) # (b, frames, 2)
|
90 |
+
|
91 |
+
# fourier features
|
92 |
+
feat = [space, frames, kappa, alpha, sigma, gamma]
|
93 |
+
feat = self.normalize_params(feat)
|
94 |
+
feat = self.material_encoder(feat) # (b, frames, n_feats * embed_dim)
|
95 |
+
|
96 |
+
damping = torch.exp(- frames * sigma.narrow(-1,0,1))
|
97 |
+
mode_coef = mode_coef * damping
|
98 |
+
ut, ut_freq, ut_coef = self.net(feat, mode_freq, mode_coef, frames, alpha, omega, Nt)
|
99 |
+
return ut, [in_freq, in_coef], [ut_freq, ut_coef]
|
100 |
+
|
101 |
+
def get_frame_time(self, times, Nf):
|
102 |
+
t_0 = times.narrow(1,0,1) # (Bs, 1, 1)
|
103 |
+
t_k = torch.ones_like(t_0).repeat(1,Nf,1).cumsum(1) / self.sr
|
104 |
+
t_k = t_k + t_0 # (Bs, Nt, 1)
|
105 |
+
return t_k
|
106 |
+
|
107 |
+
def normalize_params(self, params):
|
108 |
+
def rescale(var, scale):
|
109 |
+
minval = min(scale)
|
110 |
+
denval = max(scale) - minval
|
111 |
+
return (var - minval) / denval
|
112 |
+
space, times, kappa, alpha, sigma, gamma = params
|
113 |
+
sig_0, sig_1 = sigma.chunk(2, -1)
|
114 |
+
space = rescale(space, self.x_scale)
|
115 |
+
times = rescale(times - max(self.t_scale), self.t_scale)
|
116 |
+
kappa = rescale(kappa, self.kappa_scale)
|
117 |
+
alpha = rescale(alpha, self.alpha_scale)
|
118 |
+
sig_0 = rescale(sig_0, self.sig_0_scale)
|
119 |
+
sig_1 = rescale(sig_1, self.sig_1_scale)
|
120 |
+
gamma = rescale(gamma, self.gamma_scale)
|
121 |
+
sigma = torch.cat((sig_0, sig_1), dim=-1)
|
122 |
+
return torch.cat([space, times, kappa, alpha, sigma, gamma], dim=-1)
|
123 |
+
|
124 |
+
|
125 |
+
|
src/utils/audio.py
ADDED
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
import librosa
|
6 |
+
import soundfile as sf
|
7 |
+
from einops import rearrange
|
8 |
+
|
9 |
+
eps = np.finfo(np.float32).eps
|
10 |
+
|
11 |
+
def calculate_rms(amp):
|
12 |
+
if isinstance(amp, torch.Tensor):
|
13 |
+
return amp.pow(2).mean(-1, keepdim=True).pow(.5)
|
14 |
+
elif isinstance(amp, np.ndarray):
|
15 |
+
return np.sqrt(np.mean(np.square(amp), axis=-1) + eps)
|
16 |
+
else:
|
17 |
+
raise TypeError(f"argument 'amp' must be torch.Tensor or np.ndarray. got: {type(amp)}")
|
18 |
+
|
19 |
+
def dB2amp(dB):
|
20 |
+
return np.power(10., dB/20.)
|
21 |
+
|
22 |
+
def amp2dB(amp):
|
23 |
+
return 20. * np.log10(amp)
|
24 |
+
|
25 |
+
def rms_normalize(wav, ref_dBFS=-23.0, skip_nan=True):
|
26 |
+
exists_nan = np.isnan(np.sum(wav))
|
27 |
+
if not skip_nan:
|
28 |
+
assert not exists_nan, np.isnan(wav)
|
29 |
+
if exists_nan:
|
30 |
+
return wav, 1.
|
31 |
+
# RMS normalize
|
32 |
+
# value_dBFS = 20*log10(rms(signal) * sqrt(2)) = 20*log10(rms(signal)) + 3.0103
|
33 |
+
rms = calculate_rms(wav)
|
34 |
+
if isinstance(ref_dBFS, torch.Tensor):
|
35 |
+
ref_linear = torch.pow(10, (ref_dBFS-3.0103)/20.)
|
36 |
+
else:
|
37 |
+
ref_linear = np.power(10, (ref_dBFS-3.0103)/20.)
|
38 |
+
gain = ref_linear / (rms + eps)
|
39 |
+
wav = gain * wav
|
40 |
+
return wav, gain
|
41 |
+
|
42 |
+
def ell_infty_normalize(wav, skip_nan=True):
|
43 |
+
if isinstance(wav, np.ndarray):
|
44 |
+
''' numpy '''
|
45 |
+
exists_nan = np.isnan(np.sum(wav))
|
46 |
+
if not skip_nan:
|
47 |
+
assert not exists_nan, np.isnan(wav)
|
48 |
+
if exists_nan:
|
49 |
+
return wav, 1.
|
50 |
+
maxv = np.max(np.abs(wav), axis=-1)
|
51 |
+
# 1 if maxv == 0 else 1. / maxv
|
52 |
+
if len(list(maxv.shape)) == 0:
|
53 |
+
gain = 1 if maxv==0 else 1. / maxv
|
54 |
+
else:
|
55 |
+
gain = 1. / maxv; gain[maxv==0] = 1
|
56 |
+
elif isinstance(wav, torch.Tensor):
|
57 |
+
''' torch '''
|
58 |
+
exists_nan = torch.isnan(wav.sum())
|
59 |
+
if not skip_nan:
|
60 |
+
assert not exists_nan, torch.isnan(wav)
|
61 |
+
if exists_nan:
|
62 |
+
return wav, 1.
|
63 |
+
maxv = wav.abs().max(-1).values.unsqueeze(-1)
|
64 |
+
# 1 if maxv == 0 else 1. / maxv
|
65 |
+
gain = torch.where(maxv.eq(0),
|
66 |
+
torch.ones_like(maxv), 1. / maxv)
|
67 |
+
else:
|
68 |
+
assert False, wav
|
69 |
+
wav = gain * wav
|
70 |
+
return wav, gain
|
71 |
+
|
72 |
+
def dB_RMS(wav):
|
73 |
+
if isinstance(wav, torch.Tensor):
|
74 |
+
return 20 * torch.log10(calculate_rms(wav))
|
75 |
+
elif isinstance(wav, np.ndarray):
|
76 |
+
return 20 * np.log10(calculate_rms(wav))
|
77 |
+
|
78 |
+
def mel_basis(sr, n_fft, n_mel):
|
79 |
+
return librosa.filters.mel(sr=sr,n_fft=n_fft,n_mels=n_mel,fmin=0,fmax=sr//2,norm=1)
|
80 |
+
|
81 |
+
def inv_mel_basis(sr, n_fft, n_mel):
|
82 |
+
return librosa.filters.mel(
|
83 |
+
sr=sr, n_fft=n_fft, n_mels=n_mel, norm=None, fmin=0, fmax=sr//2,
|
84 |
+
).T
|
85 |
+
|
86 |
+
def lin_to_mel(linspec, sr, n_fft, n_mel=80):
|
87 |
+
basis = mel_basis(sr, n_fft, n_mel)
|
88 |
+
return basis @ linspec
|
89 |
+
|
90 |
+
def save_waves(est, save_dir, sr=16000):
|
91 |
+
data = []
|
92 |
+
batch_size = inp.shape[0]
|
93 |
+
for b in range(batch_size):
|
94 |
+
est_wav = est[b,0].squeeze()
|
95 |
+
wave_path = f"{save_dir}/{b}.wav"
|
96 |
+
sf.write(wave_path, est_wav, samplerate=sr)
|
97 |
+
|
98 |
+
def get_inverse_window(forward_window, frame_length, frame_step):
|
99 |
+
denom = torch.square(forward_window)
|
100 |
+
overlaps = -(-frame_length // frame_step) # Ceiling division.
|
101 |
+
denom = F.pad(denom, (0, overlaps * frame_step - frame_length))
|
102 |
+
denom = denom.reshape(overlaps, frame_step)
|
103 |
+
denom = denom.sum(0, keepdims=True)
|
104 |
+
denom = denom.tile(overlaps, 1)
|
105 |
+
denom = denom.reshape(overlaps * frame_step)
|
106 |
+
return forward_window / denom[:frame_length]
|
107 |
+
|
108 |
+
def state_to_wav(state, normalize=True, sr=48000):
|
109 |
+
''' state: (Bs, Nt, Nx) '''
|
110 |
+
assert len(list(state.shape)) == 3, state.shape
|
111 |
+
Nt = state.size(1)
|
112 |
+
vel = ((state.narrow(1,1,Nt-1) - state.narrow(1,0,Nt-1)) * sr).sum(-1)
|
113 |
+
return ell_infty_normalize(vel)[0] if normalize else vel
|
114 |
+
|
115 |
+
def state_to_spec(x, window):
|
116 |
+
''' x: (Bs, Nt, Nx, Ch)
|
117 |
+
-> (Bs, Nt, Nx, Ch*n_fft*2)
|
118 |
+
'''
|
119 |
+
Bs, Nt, Nx, Ch = x.shape
|
120 |
+
n_ffts = window.size(-1)
|
121 |
+
n_freq = n_ffts // 2 + 1
|
122 |
+
hop_length = n_ffts // 4
|
123 |
+
x = rearrange(x, 'b t x c -> (b x c) t')
|
124 |
+
s = torch.stft(x, n_ffts, hop_length=hop_length, window=window)
|
125 |
+
s = rearrange(s, '(b x c) f t k -> b t x (c f k)',
|
126 |
+
b=Bs, x=Nx, c=Ch, f=n_freq, k=2)
|
127 |
+
return s
|
128 |
+
|
129 |
+
def spec_to_state(x, window, length):
|
130 |
+
''' x: (Bs, Nt, Nx, Ch*n_fft*2)
|
131 |
+
-> (Bs, Nt, Nx, Ch)
|
132 |
+
'''
|
133 |
+
Bs, Nt, Nx, _ = x.shape
|
134 |
+
n_ffts = window.size(-1)
|
135 |
+
n_freq = n_ffts // 2 + 1
|
136 |
+
|
137 |
+
x = rearrange(x, 'b t x (c f k) -> (b x c) f t k', f=n_freq, k=2)
|
138 |
+
x = torch.istft(x, n_ffts, length=length, window=window)
|
139 |
+
x = rearrange(x, '(b x c) t -> b t x c', b=Bs, x=Nx)
|
140 |
+
return x
|
141 |
+
|
142 |
+
|
143 |
+
def to_spec(x, window, reduce_channel=True):
|
144 |
+
''' x: (Bs, Nt)
|
145 |
+
-> (Bs, Nt, Nf*2) if reduce_channel==True
|
146 |
+
-> (Bs, Nt, Nf,2) otherwise
|
147 |
+
'''
|
148 |
+
Bs, Nt = x.shape
|
149 |
+
n_ffts = window.size(-1)
|
150 |
+
n_freq = n_ffts // 2 + 1
|
151 |
+
hop_length = n_ffts // 4
|
152 |
+
s = torch.stft(x, n_ffts, hop_length=hop_length, window=window)
|
153 |
+
s = s.transpose(1,2)
|
154 |
+
if reduce_channel:
|
155 |
+
s = rearrange(s, 'b t f k -> b t (f k)',
|
156 |
+
b=Bs, f=n_freq, k=2)
|
157 |
+
return s
|
158 |
+
|
159 |
+
def from_spec(x, window, length):
|
160 |
+
''' x: (Bs, Nt, Nf*2)
|
161 |
+
-> (Bs, Nt)
|
162 |
+
'''
|
163 |
+
Bs, Nt, _ = x.shape
|
164 |
+
n_ffts = window.size(-1)
|
165 |
+
n_freq = n_ffts // 2 + 1
|
166 |
+
|
167 |
+
x = rearrange(x, 'b t (f k) -> b f t k', f=n_freq, k=2)
|
168 |
+
x = torch.istft(x, n_ffts, length=length, window=window)
|
169 |
+
return x
|
170 |
+
|
171 |
+
def adjust_gain(y, x, minmax, ref_dBFS=-23.0):
|
172 |
+
ran_gain = (minmax[1] - minmax[0]) * torch.rand_like(y.narrow(-1,0,1)) + minmax[0]
|
173 |
+
ref_linear = np.power(10, (ref_dBFS-3.0103)/20.)
|
174 |
+
ran_linear = torch.pow(10, (ran_gain-3.0103)/20.)
|
175 |
+
x_rms = calculate_rms(x)
|
176 |
+
y_rms = calculate_rms(y)
|
177 |
+
x_gain = ref_linear / (x_rms + eps)
|
178 |
+
y_gain = ref_linear / (y_rms + eps)
|
179 |
+
|
180 |
+
y_xscale = y * y_gain / x_gain
|
181 |
+
return y_xscale / ran_linear
|
182 |
+
|
183 |
+
def degrade(x, rir, noise):
|
184 |
+
''' x : (Bs, Nt)
|
185 |
+
rir : (Bs, Nt)
|
186 |
+
noise: (Bs, Nt)
|
187 |
+
'''
|
188 |
+
x_pad = F.pad(x, (0,rir.size(-1)))
|
189 |
+
w_pad = F.pad(rir, (0,rir.size(-1)))
|
190 |
+
x_fft = torch.fft.rfft(x_pad)
|
191 |
+
w_fft = torch.fft.rfft(w_pad)
|
192 |
+
wet_x = torch.fft.irfft(x_fft * w_fft).narrow(-1,0,x.size(-1))
|
193 |
+
|
194 |
+
y = adjust_gain(wet_x, x, [-0, 30]) # ser
|
195 |
+
n = adjust_gain(noise, y, [10, 30]) # snr
|
196 |
+
return y + n
|
197 |
+
|
198 |
+
def T60_to_sigma(T60, f_0, K):
|
199 |
+
''' T60 : (Bs, 2, 2) [[T60_freq_1, T60_1], [T60_freq_2, T60_2]]
|
200 |
+
f_0 : (Bs, Nt, 1) fundamental frequency
|
201 |
+
K : (Bs, Nt, 1) kappa (K == gamma * kappa_rel)
|
202 |
+
-> sig : (Bs, Nt, 2)
|
203 |
+
'''
|
204 |
+
gamma = f_0 * 2
|
205 |
+
freq1, time1 = T60.narrow(1,0,1).chunk(2,-1)
|
206 |
+
freq2, time2 = T60.narrow(1,1,1).chunk(2,-1)
|
207 |
+
|
208 |
+
zeta1 = - gamma.pow(2) + (gamma.pow(4) + 4 * K.pow(2) * (2 * math.pi * freq1).pow(2)).pow(.5)
|
209 |
+
zeta2 = - gamma.pow(2) + (gamma.pow(4) + 4 * K.pow(2) * (2 * math.pi * freq2).pow(2)).pow(.5)
|
210 |
+
sig0 = - zeta2 / time1 + zeta1 / time2
|
211 |
+
sig0 = 6 * math.log(10) * sig0 / (zeta1 - zeta2)
|
212 |
+
|
213 |
+
sig1 = 1 / time1 - 1 / time2
|
214 |
+
sig1 = 6 * math.log(10) * sig1 / (zeta1 - zeta2)
|
215 |
+
|
216 |
+
sig = torch.cat((sig0, sig1), dim=-1)
|
217 |
+
return sig
|
218 |
+
|
219 |
+
|
src/utils/control.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
def constant(f0, n, dtype=None):
|
6 |
+
''' f0 (batch_size,)
|
7 |
+
n (int)
|
8 |
+
'''
|
9 |
+
return f0.unsqueeze(-1) * torch.ones(1,n, dtype=dtype)
|
10 |
+
|
11 |
+
def linear(f1, f2, n):
|
12 |
+
''' f1 (batch_size,)
|
13 |
+
f2 (batch_size,)
|
14 |
+
n (int)
|
15 |
+
'''
|
16 |
+
out = torch.cat((f1.unsqueeze(-1),f2.unsqueeze(-1)), dim=-1) # (batch_size, 2)
|
17 |
+
out = F.interpolate(out.unsqueeze(1), size=n, mode='linear', align_corners=True).squeeze(1) # (batch_size, n)
|
18 |
+
return out
|
19 |
+
|
20 |
+
def glissando(f1, f2, n, mode='linear'):
|
21 |
+
if mode == 'linear':
|
22 |
+
return linear(f1, f2, n)
|
23 |
+
else:
|
24 |
+
raise NotImplementedError(mode)
|
25 |
+
|
26 |
+
def vibrato(f0, k, mf=[3,5], ma=0.05, ma_in_hz=False):
|
27 |
+
''' f0 (batch_size, n)
|
28 |
+
k (int): 1/sr
|
29 |
+
mf (list): modulation frequency ([min, max])
|
30 |
+
ma (float): modulation amplitude (in Hz)
|
31 |
+
ma_in_hz (bool): ma is given in Hz (else: ma is given as a weighting factor of f0)
|
32 |
+
'''
|
33 |
+
ff = f0.narrow(-1,0,1)
|
34 |
+
def get_new_vibrato(f0, k, mf, ma, ma_in_hz):
|
35 |
+
mod_frq = mf[1] * torch.rand_like(ff) + mf[0] # (B, 1)
|
36 |
+
mod_amp = ma * torch.rand_like(ff) # (B, 1)
|
37 |
+
|
38 |
+
nt = f0.size(-1) # total time
|
39 |
+
vt = torch.floor((nt // 2) * torch.rand(f0.size(0)).view(-1,1)) # vibrato time
|
40 |
+
t = torch.ones_like(f0).cumsum(-1)
|
41 |
+
m = t.gt(vt) # mask `t` for n <= vt
|
42 |
+
vibra = m * mod_amp * (1 - torch.cos(2 * np.pi * mod_frq * (t - vt) * k)) / 2
|
43 |
+
if not ma_in_hz: vibra *= f0
|
44 |
+
return vibra * torch.randn_like(ff).sign()
|
45 |
+
return f0 + get_new_vibrato(f0, k, mf, ma, ma_in_hz)
|
46 |
+
|
47 |
+
def triangle_with_velocity(vel, n, sr_t, sr_x, max_u=.1):
|
48 |
+
''' vel (batch_size,) velocity
|
49 |
+
n (int) number of samples
|
50 |
+
sr_t (int) sampling rate in time
|
51 |
+
sr_x (int) sampling rate in space
|
52 |
+
max_u (float) maximum displacement
|
53 |
+
'''
|
54 |
+
vel = vel.view(-1,1) * sr_x / sr_t # m/s to non-dimensional quantity
|
55 |
+
vel = vel * torch.ones_like(vel).repeat(1,n)
|
56 |
+
u_H = torch.relu(max_u - (max_u - vel.cumsum(1)).abs() - vel)
|
57 |
+
u_H = u_H.pow(5).clamp(max=0.01)
|
58 |
+
return u_H
|
59 |
+
|
60 |
+
|
61 |
+
|
src/utils/ddsp.py
ADDED
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.fft as fft
|
4 |
+
import numpy as np
|
5 |
+
import librosa as li
|
6 |
+
import math
|
7 |
+
|
8 |
+
def safe_log(x):
|
9 |
+
return torch.log(x + 1e-7)
|
10 |
+
|
11 |
+
@torch.no_grad()
|
12 |
+
def mean_std_loudness(dataset):
|
13 |
+
mean = 0
|
14 |
+
std = 0
|
15 |
+
n = 0
|
16 |
+
for _, _, l in dataset:
|
17 |
+
n += 1
|
18 |
+
mean += (l.mean().item() - mean) / n
|
19 |
+
std += (l.std().item() - std) / n
|
20 |
+
return mean, std
|
21 |
+
|
22 |
+
|
23 |
+
def multiscale_fft(signal, scales, overlap):
|
24 |
+
stfts = []
|
25 |
+
for s in scales:
|
26 |
+
S = torch.stft(
|
27 |
+
signal,
|
28 |
+
s,
|
29 |
+
int(s * (1 - overlap)),
|
30 |
+
s,
|
31 |
+
torch.hann_window(s).to(signal),
|
32 |
+
True,
|
33 |
+
normalized=True,
|
34 |
+
return_complex=True,
|
35 |
+
).abs()
|
36 |
+
stfts.append(S)
|
37 |
+
return stfts
|
38 |
+
|
39 |
+
|
40 |
+
def resample(x, factor: int):
|
41 |
+
batch, frame, channel = x.shape
|
42 |
+
x = x.permute(0, 2, 1).reshape(batch * channel, 1, frame)
|
43 |
+
|
44 |
+
window = torch.hann_window(
|
45 |
+
factor * 2,
|
46 |
+
dtype=x.dtype,
|
47 |
+
device=x.device,
|
48 |
+
).reshape(1, 1, -1)
|
49 |
+
y = torch.zeros(x.shape[0], x.shape[1], factor * x.shape[2]).to(x)
|
50 |
+
y[..., ::factor] = x
|
51 |
+
y[..., -1:] = x[..., -1:]
|
52 |
+
y = torch.nn.functional.pad(y, [factor, factor])
|
53 |
+
y = torch.nn.functional.conv1d(y, window)[..., :-1]
|
54 |
+
|
55 |
+
y = y.reshape(batch, channel, factor * frame).permute(0, 2, 1)
|
56 |
+
|
57 |
+
return y
|
58 |
+
|
59 |
+
|
60 |
+
|
61 |
+
def upsample(signal, factor):
|
62 |
+
signal = signal.permute(0,2,1)
|
63 |
+
signal = nn.functional.interpolate(signal, size=signal.shape[-1] * factor, mode='linear')
|
64 |
+
return signal.permute(0,2,1)
|
65 |
+
|
66 |
+
|
67 |
+
def remove_above_nyquist(amplitudes, pitch, sampling_rate):
|
68 |
+
''' amplitudes: (batch, frames, n_harmoincs)
|
69 |
+
pitch: (batch, frames, 1)
|
70 |
+
'''
|
71 |
+
n_harm = amplitudes.shape[-1]
|
72 |
+
pitches = pitch.repeat(1,1,n_harm).cumsum(-1)
|
73 |
+
aa = (pitches < sampling_rate / 2).float() + 1e-4
|
74 |
+
return amplitudes * aa
|
75 |
+
|
76 |
+
|
77 |
+
def remove_above_nyquist_mode(amplitudes, frequencies, sampling_rate):
|
78 |
+
''' amplitudes: (batch, frames, n_harmoincs)
|
79 |
+
frequencies: (batch, frames, n_harmonics)
|
80 |
+
'''
|
81 |
+
aa = (frequencies < sampling_rate / 2).float() + 1e-4
|
82 |
+
return amplitudes * aa
|
83 |
+
|
84 |
+
def scale_function(x):
|
85 |
+
''' 0 ~ 2'''
|
86 |
+
return 2 * torch.sigmoid(x)**(math.log(10)) + 1e-7
|
87 |
+
|
88 |
+
def extract_loudness(signal, sampling_rate, block_size, n_fft=2048):
|
89 |
+
S = li.stft(
|
90 |
+
signal,
|
91 |
+
n_fft=n_fft,
|
92 |
+
hop_length=block_size,
|
93 |
+
win_length=n_fft,
|
94 |
+
center=True,
|
95 |
+
)
|
96 |
+
S = np.log(abs(S) + 1e-7)
|
97 |
+
f = li.fft_frequencies(sampling_rate, n_fft)
|
98 |
+
a_weight = li.A_weighting(f)
|
99 |
+
|
100 |
+
S = S + a_weight.reshape(-1, 1)
|
101 |
+
|
102 |
+
S = np.mean(S, 0)[..., :-1]
|
103 |
+
|
104 |
+
return S
|
105 |
+
|
106 |
+
|
107 |
+
def extract_pitch(signal, sampling_rate, block_size):
|
108 |
+
length = signal.shape[-1] // block_size
|
109 |
+
f0 = crepe.predict(
|
110 |
+
signal,
|
111 |
+
sampling_rate,
|
112 |
+
step_size=int(1000 * block_size / sampling_rate),
|
113 |
+
verbose=1,
|
114 |
+
center=True,
|
115 |
+
viterbi=True,
|
116 |
+
)
|
117 |
+
f0 = f0[1].reshape(-1)[:-1]
|
118 |
+
|
119 |
+
if f0.shape[-1] != length:
|
120 |
+
f0 = np.interp(
|
121 |
+
np.linspace(0, 1, length, endpoint=False),
|
122 |
+
np.linspace(0, 1, f0.shape[-1], endpoint=False),
|
123 |
+
f0,
|
124 |
+
)
|
125 |
+
|
126 |
+
return f0
|
127 |
+
|
128 |
+
|
129 |
+
def harmonic_synth(pitch, amplitudes, sampling_rate):
|
130 |
+
n_harmonic = amplitudes.shape[-1]
|
131 |
+
omega = torch.cumsum(2 * math.pi * pitch / sampling_rate, 1)
|
132 |
+
omegas = omega * torch.arange(1, n_harmonic + 1).to(omega)
|
133 |
+
signal = (torch.sin(omegas) * amplitudes).sum(-1, keepdim=True)
|
134 |
+
return signal
|
135 |
+
|
136 |
+
def modal_synth(modes, amplitude, sampling_rate, n_chunks=16):
|
137 |
+
freqs = modes.chunk(n_chunks, 1)
|
138 |
+
coefs = amplitude.chunk(n_chunks, 1)
|
139 |
+
lastf = torch.zeros_like(freqs[0])
|
140 |
+
sols = []
|
141 |
+
for f, c in zip(freqs, coefs):
|
142 |
+
fcs = f.cumsum(1) + lastf
|
143 |
+
sol = (torch.cos(fcs) * c).sum(-1, keepdim=True)
|
144 |
+
lastf = fcs.narrow(1,-1,1)
|
145 |
+
sols.append(sol)
|
146 |
+
return torch.cat(sols, 1)
|
147 |
+
|
148 |
+
|
149 |
+
def amp_to_impulse_response(amp, target_size):
|
150 |
+
amp = torch.stack([amp, torch.zeros_like(amp)], -1)
|
151 |
+
amp = torch.view_as_complex(amp)
|
152 |
+
amp = fft.irfft(amp)
|
153 |
+
|
154 |
+
filter_size = amp.shape[-1]
|
155 |
+
|
156 |
+
amp = torch.roll(amp, filter_size // 2, -1)
|
157 |
+
win = torch.hann_window(filter_size, dtype=amp.dtype, device=amp.device)
|
158 |
+
|
159 |
+
amp = amp * win
|
160 |
+
|
161 |
+
amp = nn.functional.pad(amp, (0, int(target_size) - int(filter_size)))
|
162 |
+
amp = torch.roll(amp, -filter_size // 2, -1)
|
163 |
+
|
164 |
+
return amp
|
165 |
+
|
166 |
+
|
167 |
+
def fft_convolve(signal, kernel):
|
168 |
+
signal = nn.functional.pad(signal, (0, signal.shape[-1]))
|
169 |
+
kernel = nn.functional.pad(kernel, (kernel.shape[-1], 0))
|
170 |
+
|
171 |
+
output = fft.irfft(fft.rfft(signal) * fft.rfft(kernel))
|
172 |
+
output = output[..., output.shape[-1] // 2:]
|
173 |
+
|
174 |
+
return output
|
175 |
+
|
src/utils/misc.py
ADDED
@@ -0,0 +1,336 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import yaml
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from scipy.interpolate import RectBivariateSpline
|
7 |
+
|
8 |
+
from contextlib import contextmanager,redirect_stderr,redirect_stdout
|
9 |
+
from os import devnull
|
10 |
+
|
11 |
+
chars = [c for c in '0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ']
|
12 |
+
|
13 |
+
@contextmanager
|
14 |
+
def suppress_stdout_stderr():
|
15 |
+
"""A context manager that redirects stdout and stderr to devnull"""
|
16 |
+
with open(devnull, 'w') as fnull:
|
17 |
+
with redirect_stderr(fnull) as err, redirect_stdout(fnull) as out:
|
18 |
+
yield (err, out)
|
19 |
+
|
20 |
+
def batchify(x, batch_size, n_samples):
|
21 |
+
pass
|
22 |
+
|
23 |
+
def random_str(length=8):
|
24 |
+
return "".join(np.random.choice(chars, length))
|
25 |
+
|
26 |
+
def sqrt(x):
|
27 |
+
return x.pow(.5) if isinstance(x, torch.Tensor) else x**.5
|
28 |
+
|
29 |
+
def soft_bow(v_rel, a=100):
|
30 |
+
return np.sqrt(2*a) * v_rel * torch.exp(-a * v_rel**2 + 0.5)
|
31 |
+
|
32 |
+
def hard_bow(v_rel, a=5, eps=0.1, hard_sign=True):
|
33 |
+
sign = torch.sign(v_rel) if hard_sign else torch.tanh(100 * v_rel)
|
34 |
+
return sign * (eps + (1-eps) * torch.exp(-a * v_rel.abs()))
|
35 |
+
|
36 |
+
def raised_cosine(N, h, ctr, wid, n):
|
37 |
+
''' N (int): number of maximal samples in space
|
38 |
+
h (float): spatial grid cell width
|
39 |
+
ctr (B,1,1): center points for each batch
|
40 |
+
wid (B,1,1): width lengths for each batch
|
41 |
+
n (B,): number of actual samples in space
|
42 |
+
'''
|
43 |
+
xax = torch.linspace(h, 1, N).to(ctr.device).view(1,-1,1) # (1, N, 1)
|
44 |
+
ctr = (ctr * n / N)
|
45 |
+
wid = wid / N
|
46 |
+
ind = torch.sign(torch.relu(-(xax - ctr - wid / 2) * (xax - ctr + wid / 2)))
|
47 |
+
out = 0.5 * ind * (1 + torch.cos(2 * np.pi * (xax - ctr) / wid))
|
48 |
+
return out / out.abs().sum(1, keepdim=True) # (batch_Size, N, 1)
|
49 |
+
|
50 |
+
def floor_dirac_delta(n, ctr, N):
|
51 |
+
''' torch::Tensor n, // number of samples in space
|
52 |
+
torch::Tensor ctr, // center point of raised cosine curve
|
53 |
+
int N
|
54 |
+
'''
|
55 |
+
xax = torch.ones_like(ctr).view(-1,1,1).repeat(1,N,1).cumsum(1) - 1
|
56 |
+
idx = torch.floor(ctr * n).view(-1,1,1)
|
57 |
+
#return torch.floor(xax).eq(idx).to(n.dtype()) # (batch_size, N, 1)
|
58 |
+
return torch.floor(xax).eq(idx) # (batch_size, N, 1)
|
59 |
+
|
60 |
+
def triangular(N, n, p_x, p_a):
|
61 |
+
''' N (int): number of maximal samples in space
|
62 |
+
n (B, 1, 1): number of actual samples in space
|
63 |
+
p_x (B, Nt, 1): peak position
|
64 |
+
p_a (B, Nt, 1): peak amplitude
|
65 |
+
'''
|
66 |
+
vel_l = torch.where(p_x.le(0), torch.zeros_like(p_x), p_a / p_x / n)
|
67 |
+
vel_r = torch.where(p_x.le(0), torch.zeros_like(p_x), p_a / (1-p_x) / n)
|
68 |
+
vel_l = ((vel_l * torch.ones_like(vel_l).repeat(1,1,N)).cumsum(2) - vel_l).clamp(min=0)
|
69 |
+
vel_r = ((vel_r * torch.ones_like(vel_r).repeat(1,1,N)).cumsum(2) - vel_r * (N-n+1)).clamp(min=0).flip(2)
|
70 |
+
tri = torch.minimum(vel_l, vel_r)
|
71 |
+
assert not torch.isnan(tri).any(), torch.isnan(tri.flatten(1).sum(1))
|
72 |
+
return tri
|
73 |
+
|
74 |
+
def pre_shaper(x, sr, velocity=10):
|
75 |
+
w = torch.tanh(torch.ones_like(x).cumsum(-1) / sr * velocity)
|
76 |
+
return w * x
|
77 |
+
|
78 |
+
def post_shaper(x, sr, pulloff, velocity=100):
|
79 |
+
offset = x.size(-1) - int(sr * pulloff)
|
80 |
+
w = torch.tanh(torch.ones_like(x).cumsum(-1) / sr * velocity).flip(-1)
|
81 |
+
w = F.pad(w.narrow(-1,offset,w.size(-1)-offset), (0,offset))
|
82 |
+
return w * x
|
83 |
+
|
84 |
+
def random_uniform(floor, ceiling, size=None, weight=None, dtype=None):
|
85 |
+
if not isinstance(size, tuple): size = (size,)
|
86 |
+
if weight is None: weight = torch.ones(size, dtype=dtype)
|
87 |
+
# NOTE: torch.rand(..., dtype=dtype) for dtype \in [torch.float32, torch.float64]
|
88 |
+
# can result in different random number generation
|
89 |
+
# (for different precisions; despite fixiing the random seed.)
|
90 |
+
return (ceiling - floor) * torch.rand(size=size).to(dtype) * weight + floor
|
91 |
+
|
92 |
+
def equidistant(floor, ceiling, steps, dtype=None):
|
93 |
+
return torch.linspace(floor, ceiling, steps).to(dtype)
|
94 |
+
|
95 |
+
def get_masks(model_name, bs, disjoint=True):
|
96 |
+
''' setting `disjoint=False` enables multiple excitations allowed
|
97 |
+
(e.g., bowing over hammered strings.) While this could be a
|
98 |
+
charming choice, but it can also drive the simulation unstable.
|
99 |
+
'''
|
100 |
+
# boolean mask that determines whether to impose each excitation
|
101 |
+
if model_name.endswith('bow'):
|
102 |
+
bow_mask = torch.ones( size=(bs,)).view(-1,1,1)
|
103 |
+
hammer_mask = torch.zeros(size=(bs,)).view(-1,1,1)
|
104 |
+
elif model_name.endswith('hammer'):
|
105 |
+
bow_mask = torch.zeros(size=(bs,)).view(-1,1,1)
|
106 |
+
hammer_mask = torch.ones( size=(bs,)).view(-1,1,1)
|
107 |
+
elif model_name.endswith('pluck'):
|
108 |
+
bow_mask = torch.zeros(size=(bs,)).view(-1,1,1)
|
109 |
+
hammer_mask = torch.zeros(size=(bs,)).view(-1,1,1)
|
110 |
+
else:
|
111 |
+
bow_mask = torch.rand(size=(bs,)).gt(0.5).view(-1,1,1)
|
112 |
+
hammer_mask = torch.rand(size=(bs,)).gt(0.5).view(-1,1,1)
|
113 |
+
if disjoint:
|
114 |
+
both_are_true = torch.logical_and(
|
115 |
+
torch.logical_or(bow_mask, hammer_mask),
|
116 |
+
torch.logical_or(bow_mask, hammer_mask.logical_not())
|
117 |
+
)
|
118 |
+
hammer_mask[both_are_true] = False
|
119 |
+
bow_mask = bow_mask.view(-1,1,1)
|
120 |
+
hammer_mask = hammer_mask.view(-1,1,1)
|
121 |
+
return [bow_mask, hammer_mask]
|
122 |
+
|
123 |
+
def f0_interpolate(f0_1, n_frames, tmax):
|
124 |
+
t_0 = np.linspace(0, tmax, n_frames)
|
125 |
+
t_1 = np.linspace(0, tmax, f0_1.shape[0])
|
126 |
+
return np.interp(t_0, t_1, f0_1)
|
127 |
+
|
128 |
+
def interpolate1d(u, xaxis, xvals, k=5):
|
129 |
+
''' u: (1, Nx)
|
130 |
+
xaxis: (1, Nx_input)
|
131 |
+
xvals: (1, Nx_output)
|
132 |
+
-> (1, Nx_output)
|
133 |
+
'''
|
134 |
+
t = np.arange(k)[:,None] / k
|
135 |
+
rbs = RectBivariateSpline(t, xaxis, u.repeat(k,0), kx=1, ky=k)
|
136 |
+
return rbs(t, xvals, grid=True)[k//2][None,:]
|
137 |
+
|
138 |
+
def interpolate(u, taxis, xaxis, xvals, kx=5, ky=5):
|
139 |
+
''' u: (Nt, Nx)
|
140 |
+
taxis: (Nt, 1)
|
141 |
+
xaxis: (1, Nx_input)
|
142 |
+
xvals: (1, Nx_output)
|
143 |
+
-> (Nt, Nx_output)
|
144 |
+
'''
|
145 |
+
rbs = RectBivariateSpline(taxis, xaxis, u, kx=kx, ky=ky)
|
146 |
+
return rbs(taxis, xvals, grid=True)
|
147 |
+
|
148 |
+
def torch_interpolate(x, scale_factor):
|
149 |
+
y = F.interpolate(x, scale_factor=scale_factor)
|
150 |
+
res = x.size(-1) - y.size(-1)
|
151 |
+
if res % 2 == 0: y = F.pad(y, (res//2, res//2))
|
152 |
+
else: y = F.pad(y, (res//2, res//2+1))
|
153 |
+
return y
|
154 |
+
|
155 |
+
|
156 |
+
def minmax_normalize(x, dim=-1):
|
157 |
+
x_min = x.min(dim, keepdim=True).values
|
158 |
+
x = x - x_min
|
159 |
+
x_max = x.max(dim, keepdim=True).values
|
160 |
+
x = x / x_max
|
161 |
+
return x
|
162 |
+
|
163 |
+
def get_minmax(x):
|
164 |
+
if np.isnan(x.sum()):
|
165 |
+
return None, None
|
166 |
+
return np.nan_to_num(x.min()), np.nan_to_num(x.max())
|
167 |
+
|
168 |
+
def select_with_batched_index(input, dim, index):
|
169 |
+
''' input: (bs, ..., n, ...)
|
170 |
+
dim : (int)
|
171 |
+
index: (bs, ..., 1, ...) index to select on dim `dim`
|
172 |
+
-> out : (bs, ..., 1, ...) for each batch, select `index`-th element on dim `dim`
|
173 |
+
'''
|
174 |
+
assert input.size(0) == index.size(0), [input.shpae, index.shape]
|
175 |
+
bs = input.size(0)
|
176 |
+
ins = input.chunk(bs, 0)
|
177 |
+
idx = index.chunk(bs, 0)
|
178 |
+
out = []
|
179 |
+
for b in range(bs):
|
180 |
+
out.append(batched_index_select(ins[b], dim, idx[b]))
|
181 |
+
return torch.cat(out, dim=0)
|
182 |
+
|
183 |
+
def batched_index_select(input, dim, index):
|
184 |
+
''' input: (..., n, ...)
|
185 |
+
dim : (int)
|
186 |
+
index: (..., k, ...) index to select on dim `dim`
|
187 |
+
-> out : (..., k, ...) select k out of n elements on dim `dim`
|
188 |
+
'''
|
189 |
+
Nx = len(list(input.shape))
|
190 |
+
expanse = [-1 if k==(dim % Nx) else 1 for k in range(Nx)]
|
191 |
+
tiler = [1 if k==(dim % Nx) else n for k, n in enumerate(input.shape)]
|
192 |
+
index = index.to(torch.int64).view(expanse).tile(tiler)
|
193 |
+
return torch.gather(input, dim, index)
|
194 |
+
|
195 |
+
def random_index(max_N, idx_N):
|
196 |
+
if max_N < idx_N:
|
197 |
+
# choosing with replacement
|
198 |
+
return torch.randint(0, max_N, (idx_N,))
|
199 |
+
else:
|
200 |
+
# choosing without replacement
|
201 |
+
return torch.randperm(max_N)[:idx_N]
|
202 |
+
|
203 |
+
def ell_infty_normalize(x, normalize_dims=1):
|
204 |
+
eps = torch.finfo(x.dtype).eps
|
205 |
+
x_shape = list(x.shape)
|
206 |
+
m_shape = x_shape[:normalize_dims] + [1] * (len(x_shape) - normalize_dims)
|
207 |
+
x_max = x.abs().flatten(normalize_dims).max(normalize_dims).values + eps
|
208 |
+
x_gain = 1. / x_max.view(m_shape)
|
209 |
+
return x * x_gain, x_gain
|
210 |
+
|
211 |
+
def sinusoidal_embedding(x, n, gain=10000, dim=-1):
|
212 |
+
''' let `x` be normalized to be in the nondimensional (0 ~ 1) range '''
|
213 |
+
assert n % 2 == 0, n
|
214 |
+
x = x.unsqueeze(-1)
|
215 |
+
shape = [1] * len(list(x.shape)); shape[dim] = -1 # e.g., [1,1,-1]
|
216 |
+
half_n = n // 2
|
217 |
+
|
218 |
+
expnt = torch.arange(half_n, device=x.device, dtype=x.dtype).view(shape)
|
219 |
+
_embed = torch.exp(expnt * -(np.log(gain) / (half_n - 1)))
|
220 |
+
_embed = torch.exp(expnt * -(np.log(gain) / (half_n - 1)))
|
221 |
+
_embed = x * _embed
|
222 |
+
emb = torch.cat((torch.sin(_embed), torch.cos(_embed)), dim)
|
223 |
+
return emb # list(x.shape) + [n]
|
224 |
+
|
225 |
+
def fourier_feature(x, B):
|
226 |
+
''' x: (Bs, ..., in_dim)
|
227 |
+
B: (in_dim, out_dim)
|
228 |
+
'''
|
229 |
+
if B is None:
|
230 |
+
return x
|
231 |
+
else:
|
232 |
+
x_proj = (2.*np.pi*x) @ B
|
233 |
+
return torch.cat((torch.sin(x_proj), torch.cos(x_proj)), dim=-1)
|
234 |
+
|
235 |
+
def save_simulation_data(directory, excitation_type, overall_results, constants):
|
236 |
+
os.makedirs(directory, exist_ok=True)
|
237 |
+
string_params = overall_results.pop('string_params')
|
238 |
+
hammer_params = overall_results.pop('hammer_params')
|
239 |
+
bow_params = overall_results.pop('bow_params')
|
240 |
+
simulation_dict = overall_results
|
241 |
+
string_dict = {
|
242 |
+
'kappa': string_params[0],
|
243 |
+
'alpha': string_params[1],
|
244 |
+
'u0' : string_params[2],
|
245 |
+
'v0' : string_params[3],
|
246 |
+
'f0' : string_params[4],
|
247 |
+
'pos' : string_params[5],
|
248 |
+
'T60' : string_params[6],
|
249 |
+
'target_f0': string_params[7],
|
250 |
+
}
|
251 |
+
hammer_dict = {
|
252 |
+
'x_H' : hammer_params[0],
|
253 |
+
'v_H' : hammer_params[1],
|
254 |
+
'u_H' : hammer_params[2],
|
255 |
+
'w_H' : hammer_params[3],
|
256 |
+
'M_r' : hammer_params[4],
|
257 |
+
'alpha': hammer_params[5],
|
258 |
+
}
|
259 |
+
bow_dict = {
|
260 |
+
'x_B' : bow_params[0],
|
261 |
+
'v_B' : bow_params[1],
|
262 |
+
'F_B' : bow_params[2],
|
263 |
+
'phi_0': bow_params[3],
|
264 |
+
'phi_1': bow_params[4],
|
265 |
+
'wid_B': bow_params[5],
|
266 |
+
}
|
267 |
+
|
268 |
+
def sample(val):
|
269 |
+
try:
|
270 |
+
_val = val.item(0)
|
271 |
+
except AttributeError as err:
|
272 |
+
if isinstance(val, float) or isinstance(val, int):
|
273 |
+
_val = val
|
274 |
+
else:
|
275 |
+
raise err
|
276 |
+
return _val
|
277 |
+
short_configuration = {
|
278 |
+
'excitation_type': excitation_type,
|
279 |
+
'theta_t' : constants[1],
|
280 |
+
'lambda_c': constants[2],
|
281 |
+
}
|
282 |
+
short_configuration['value-string'] = {}
|
283 |
+
for key, val in string_dict.items():
|
284 |
+
short_configuration['value-string'].update({ key : sample(val) })
|
285 |
+
short_configuration['value-hammer'] = {}
|
286 |
+
for key, val in hammer_dict.items():
|
287 |
+
short_configuration['value-hammer'].update({ key : sample(val) })
|
288 |
+
short_configuration['value-bow'] = {}
|
289 |
+
for key, val in bow_dict.items():
|
290 |
+
short_configuration['value-bow'].update({ key : sample(val) })
|
291 |
+
|
292 |
+
np.savez_compressed(f'{directory}/simulation.npz', **simulation_dict)
|
293 |
+
np.savez_compressed(f'{directory}/string_params.npz', **string_dict)
|
294 |
+
np.savez_compressed(f'{directory}/hammer_params.npz', **hammer_dict)
|
295 |
+
np.savez_compressed(f'{directory}/bow_params.npz', **bow_dict)
|
296 |
+
|
297 |
+
with open(f"{directory}/simulation_config.yaml", 'w') as f:
|
298 |
+
yaml.dump(short_configuration, f, default_flow_style=False)
|
299 |
+
|
300 |
+
def add_noise(x, c, vals, eps=1e-5):
|
301 |
+
noise = eps * torch.randn_like(x)
|
302 |
+
for val in vals:
|
303 |
+
mask = torch.where(c == val, torch.ones_like(c), torch.zeros_like(c))
|
304 |
+
x = x + mask * noise
|
305 |
+
return x
|
306 |
+
|
307 |
+
def downsample(x, factor=None, size=None):
|
308 |
+
''' x: (Bs, Nt) -> (Bs, Nt // factor)
|
309 |
+
'''
|
310 |
+
if size is None:
|
311 |
+
size = x.size(1) // factor + bool(x.size(1) % factor)
|
312 |
+
else:
|
313 |
+
assert factor is None, [factor, size]
|
314 |
+
return F.interpolate(x.unsqueeze(1), size=size, mode='linear').squeeze(1)
|
315 |
+
|
316 |
+
|
317 |
+
if __name__=='__main__':
|
318 |
+
N = 10
|
319 |
+
B = 1
|
320 |
+
h = 1 / N
|
321 |
+
ctr = 0.5 * torch.ones(B).view(-1,1,1)
|
322 |
+
wid = 1 * torch.ones(B).view(-1,1,1)
|
323 |
+
n = N * torch.ones(B)
|
324 |
+
''' N (int): number of maximal samples in space
|
325 |
+
h (float): spatial grid cell width
|
326 |
+
ctr (B,1,1): center points for each batch
|
327 |
+
wid (B,1,1): width lengths for each batch
|
328 |
+
n (B,): number of actual samples in space
|
329 |
+
'''
|
330 |
+
c = raised_cosine(N, h, ctr, wid, n)
|
331 |
+
print(c.shape)
|
332 |
+
import matplotlib.pyplot as plt
|
333 |
+
plt.figure()
|
334 |
+
plt.plot(c[0,:,0])
|
335 |
+
plt.savefig('asdf.png')
|
336 |
+
|
src/utils/plot.py
ADDED
@@ -0,0 +1,1132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import shutil
|
3 |
+
import subprocess
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import numpy as np
|
7 |
+
import librosa
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
import scipy
|
10 |
+
from src.utils.control import *
|
11 |
+
from src.utils.misc import soft_bow, hard_bow, sinusoidal_embedding
|
12 |
+
from src.utils.audio import rms_normalize
|
13 |
+
import soundfile as sf
|
14 |
+
|
15 |
+
plt.rc('text', usetex=True)
|
16 |
+
plt.rc('font', family='serif')
|
17 |
+
|
18 |
+
def gt_param(TF=5, sr=44100):
|
19 |
+
sr = 44100
|
20 |
+
NF = int(sr * TF)
|
21 |
+
k = 1 / sr
|
22 |
+
TRANS = int(0.05 * sr)
|
23 |
+
x_bow = torch.linspace(0.25, 0.45, NF)
|
24 |
+
v_bow = 0.1 * torch.tanh(torch.linspace(0., 10, NF))
|
25 |
+
F_bow = torch.cat((
|
26 |
+
torch.linspace(100, 120, NF//8 - TRANS), torch.zeros(TRANS),
|
27 |
+
100 * torch.ones(NF//8 - TRANS), torch.zeros(TRANS),
|
28 |
+
100 * torch.ones(NF//8 - TRANS), torch.zeros(TRANS),
|
29 |
+
torch.linspace(100, 80, NF//8 - TRANS), torch.zeros(TRANS),
|
30 |
+
80 * torch.ones(NF//4),
|
31 |
+
torch.zeros(NF//4),
|
32 |
+
), dim=-1)
|
33 |
+
|
34 |
+
f0 = torch.cat((
|
35 |
+
glissando(98,110, NF//8),
|
36 |
+
constant(130.81, NF//8),
|
37 |
+
glissando(146.83, 164.81, NF//8),
|
38 |
+
constant(207.65, NF//8),
|
39 |
+
vibrato(207.65, NF//4, k, 5, 10),
|
40 |
+
constant(207.65, NF//4),
|
41 |
+
), dim=-1)
|
42 |
+
F_bow = F.pad(F_bow, (NF-F_bow.size(-1),0))
|
43 |
+
f0 = F.pad(f0, (NF-f0.size(-1),0))
|
44 |
+
|
45 |
+
#wid = torch.linspace(0.05, 0.05, NF)
|
46 |
+
#rp = np.array([0.3, 0.7])
|
47 |
+
#T60 = np.array([[100, 8], [2000, 5]])
|
48 |
+
|
49 |
+
return [x_bow, v_bow, F_bow, f0]
|
50 |
+
|
51 |
+
def param(est_param, gt_param, save_path):
|
52 |
+
e_x_bow, e_v_bow, e_F_bow, e_f0 = [item.detach().cpu().numpy() for item in est_param[:4]]
|
53 |
+
g_x_bow, g_v_bow, g_F_bow, g_f0 = [item.cpu().numpy() for item in gt_param]
|
54 |
+
|
55 |
+
fig, ax = plt.subplots(figsize=(7,7), nrows=4, ncols=1)
|
56 |
+
|
57 |
+
ax[0].plot(g_x_bow, 'b:')
|
58 |
+
ax[0].plot(e_x_bow, 'k-')
|
59 |
+
ax[0].axhline(y=0, c='k', lw=.5)
|
60 |
+
ax[0].set_ylabel('bow pos')
|
61 |
+
|
62 |
+
ax[1].plot(g_v_bow, 'b:')
|
63 |
+
ax[1].plot(e_v_bow, 'k-')
|
64 |
+
ax[1].axhline(y=0, c='k', lw=.5)
|
65 |
+
ax[1].set_ylabel('bow vel')
|
66 |
+
|
67 |
+
ax[2].plot(g_F_bow, 'b:')
|
68 |
+
ax[2].plot(e_F_bow, 'k-')
|
69 |
+
ax[2].axhline(y=0, c='k', lw=.5)
|
70 |
+
ax[2].set_ylabel('bow force')
|
71 |
+
|
72 |
+
ax[3].plot(g_f0, 'b:')
|
73 |
+
ax[3].plot(e_f0, 'k-')
|
74 |
+
ax[3].axhline(y=0, c='k', lw=.5)
|
75 |
+
ax[3].set_ylabel('f0')
|
76 |
+
|
77 |
+
plt.tight_layout()
|
78 |
+
plt.savefig(save_path)
|
79 |
+
plt.clf()
|
80 |
+
plt.close()
|
81 |
+
|
82 |
+
|
83 |
+
def simulation_data(
|
84 |
+
save_dir,
|
85 |
+
uout, zout, v_r_out, F_H_out, u_H_out,
|
86 |
+
state_u, state_z,
|
87 |
+
string_params, bow_params, hammer_params,
|
88 |
+
**kwargs,
|
89 |
+
):
|
90 |
+
N = min(1000, uout.shape[0])
|
91 |
+
|
92 |
+
kappa, alpha, u0, v0, f0, pos, T60, target_f0 = string_params
|
93 |
+
x_b, v_b, F_b, phi_0, phi_1, wid_b = bow_params
|
94 |
+
x_H, v_H, u_H, w_H, M_r, alpha_H = hammer_params
|
95 |
+
|
96 |
+
max_disp = np.max(np.abs(uout[:N]))
|
97 |
+
rels = torch.linspace(-1,1,100)
|
98 |
+
prof = hard_bow(rels, phi_0, phi_1)
|
99 |
+
|
100 |
+
# plot string params
|
101 |
+
fig, ax = plt.subplots(figsize=(7,7), nrows=5, ncols=1)
|
102 |
+
|
103 |
+
ax[0].plot(f0, 'k-')
|
104 |
+
ax[0].axhline(y=0, c='k', lw=.5)
|
105 |
+
ax[0].set_ylabel('f0')
|
106 |
+
ax[0].yaxis.tick_right()
|
107 |
+
ax[0].set_ylim([0, 500])
|
108 |
+
|
109 |
+
ax[1].plot(np.linspace(0,1,state_u.shape[-1]), state_u[-1], 'k-')
|
110 |
+
ax[1].axvline(x=pos, c='r', lw=.5); ax[1].axvline(x=x_b[-1], c='b', lw=.5)
|
111 |
+
ax[1].set_ylabel('transverse state')
|
112 |
+
ax[1].yaxis.tick_right()
|
113 |
+
#ax[1].set_ylim([-max_disp, max_disp])
|
114 |
+
|
115 |
+
ax[2].plot(np.linspace(0,1,state_z.shape[-1]), state_z[-1], 'k-')
|
116 |
+
ax[2].axvline(x=pos, c='r', lw=.5); ax[1].axvline(x=x_b[-1], c='b', lw=.5)
|
117 |
+
ax[2].set_ylabel('longitudinal state')
|
118 |
+
ax[2].yaxis.tick_right()
|
119 |
+
#ax[2].set_ylim([-max_disp, max_disp])
|
120 |
+
|
121 |
+
ax[3].plot(np.arange(N), uout[:N], 'k-')
|
122 |
+
ax[3].axhline(y=0, c='k', lw=.5)
|
123 |
+
ax[3].set_ylabel('output')
|
124 |
+
ax[3].yaxis.tick_right()
|
125 |
+
ax[3].set_ylim([-max_disp, max_disp])
|
126 |
+
|
127 |
+
ax[4].plot(np.arange(N), zout[:N], 'k-')
|
128 |
+
ax[4].axhline(y=0, c='k', lw=.5)
|
129 |
+
ax[4].set_ylabel('output')
|
130 |
+
ax[4].yaxis.tick_right()
|
131 |
+
#ax[4].set_ylim([-max_disp, max_disp])
|
132 |
+
|
133 |
+
plt.tight_layout()
|
134 |
+
plt.savefig(f"{save_dir}/string.png")
|
135 |
+
plt.clf()
|
136 |
+
plt.close()
|
137 |
+
|
138 |
+
# plot bow params
|
139 |
+
fig, ax = plt.subplots(figsize=(7,7), nrows=3, ncols=2)
|
140 |
+
|
141 |
+
ax[0,0].plot(x_b, 'k-') ; ax[0,1].plot(rels.numpy(), prof.numpy(), 'k-')
|
142 |
+
ax[0,0].axhline(y=0, c='k', lw=.5) ; ax[0,1].axhline(y=0, c='k', lw=.5)
|
143 |
+
ax[0,0].set_ylabel('bowing position') ; ax[0,1].set_ylabel('bow friction fn')
|
144 |
+
ax[0,0].yaxis.tick_right() ; ax[0,1].yaxis.tick_right()
|
145 |
+
ax[0,0].set_ylim([0, 1]) ; ax[0,1].set_ylim([-1.5, 1.5])
|
146 |
+
|
147 |
+
ax[1,0].plot(v_b, 'k-') ; ax[1,1].plot(np.arange(N), v_r_out[:N], 'k-')
|
148 |
+
ax[1,0].axhline(y=0, c='k', lw=.5) ; ax[1,1].axhline(y=0, c='k', lw=.5)
|
149 |
+
ax[1,0].set_ylabel('bowing velocity') ; ax[1,1].set_ylabel('rel vel (attack)')
|
150 |
+
ax[1,0].yaxis.tick_right() ; ax[1,1].yaxis.tick_right()
|
151 |
+
ax[1,0].set_ylim([0, 0.5]) ; ax[1,1].set_ylim([-2, 2])
|
152 |
+
|
153 |
+
ax[2,0].plot(F_b, 'k-') ; ax[2,1].plot(np.arange(N), v_r_out[-N:], 'k-')
|
154 |
+
ax[2,0].axhline(y=0, c='k', lw=.5) ; ax[2,1].axhline(y=0, c='k', lw=.5)
|
155 |
+
ax[2,0].set_ylabel('bowing force') ; ax[2,1].set_ylabel('rel vel (release)')
|
156 |
+
ax[2,0].yaxis.tick_right() ; ax[2,1].yaxis.tick_right()
|
157 |
+
ax[2,0].set_ylim([0, 100]) ; ax[2,1].set_ylim([-2, 2])
|
158 |
+
|
159 |
+
plt.tight_layout()
|
160 |
+
plt.savefig(f"{save_dir}/bow.png")
|
161 |
+
plt.clf()
|
162 |
+
plt.close()
|
163 |
+
|
164 |
+
|
165 |
+
sr = 48000
|
166 |
+
Nt = len(v_r_out)
|
167 |
+
Nx = state_u.shape[-1]
|
168 |
+
a_f = (v_r_out[1:] - v_r_out[:Nt-1]) * sr
|
169 |
+
F_f = a_f / Nx
|
170 |
+
mu = F_f / F_b[-(Nt-1):]
|
171 |
+
vr = v_r_out[:Nt-1]
|
172 |
+
rels = torch.linspace(np.min(vr)-.1,np.max(vr)+.1,100)
|
173 |
+
prof = hard_bow(rels, phi_0, phi_1)
|
174 |
+
#prof = soft_bow(rels, phi_0)
|
175 |
+
|
176 |
+
fig, ax = plt.subplots(figsize=(4,4), nrows=1, ncols=1)
|
177 |
+
#ax.plot(rels.numpy(), prof.numpy(), 'r--')
|
178 |
+
ax.fill_between(rels.numpy(), prof.numpy(), alpha=0.2, facecolor='r')
|
179 |
+
ax.plot(vr, mu, 'k-')
|
180 |
+
ax.axhline(y=0, c='k', lw=.5)
|
181 |
+
ax.set_xlabel('Relative velocity')
|
182 |
+
ax.set_ylabel('Friction coefficient')
|
183 |
+
ax.set_ylim([-1.5, 1.5])
|
184 |
+
|
185 |
+
plt.tight_layout()
|
186 |
+
plt.savefig(f"{save_dir}/bow-velforce.pdf")
|
187 |
+
plt.clf()
|
188 |
+
plt.close()
|
189 |
+
|
190 |
+
# plot string params
|
191 |
+
fig, ax = plt.subplots(figsize=(7,7), nrows=2, ncols=1)
|
192 |
+
|
193 |
+
sr = 48000
|
194 |
+
# ms
|
195 |
+
t_1 = 0; Nt_1 = int(sr * t_1 * 1e-3)
|
196 |
+
#t_2 = 3; Nt_2 = int(sr * t_2 * 1e-3)
|
197 |
+
t_2 = 8; Nt_2 = int(sr * t_2 * 1e-3)
|
198 |
+
time = np.linspace(t_1, t_2, Nt_2 - Nt_1)
|
199 |
+
ax[0].plot(time, u_H_out[Nt_1:Nt_2], 'k-')
|
200 |
+
ax[0].axhline(y=0, c='k', lw=.5)
|
201 |
+
ax[0].set_ylabel('hammer displacement')
|
202 |
+
ax[0].yaxis.tick_right()
|
203 |
+
#ax[0].set_ylim([0, 0.1])
|
204 |
+
|
205 |
+
ax[1].plot(time, F_H_out[Nt_1:Nt_2], 'k-')
|
206 |
+
ax[1].axhline(y=0, c='k', lw=.5)
|
207 |
+
ax[1].set_ylabel('hammer force')
|
208 |
+
ax[1].yaxis.tick_right()
|
209 |
+
#ax[1].set_ylim([0, 10000])
|
210 |
+
|
211 |
+
plt.tight_layout()
|
212 |
+
plt.savefig(f"{save_dir}/hammer.png")
|
213 |
+
plt.clf()
|
214 |
+
plt.close()
|
215 |
+
|
216 |
+
|
217 |
+
def state_specs(save_path, analytic, estimate, simulate):
|
218 |
+
tf = 100
|
219 |
+
Nt, Nx = simulate.shape
|
220 |
+
nt = Nt // tf
|
221 |
+
nx = Nx // 2
|
222 |
+
diff_ana = analytic - simulate
|
223 |
+
diff_est = estimate - simulate
|
224 |
+
|
225 |
+
maxval = np.max(np.abs(simulate))
|
226 |
+
maxerr = max(np.max(np.abs(diff_ana)), np.max(np.abs(diff_est)))
|
227 |
+
|
228 |
+
nrows = 3; ncols = 2
|
229 |
+
fig, ax = plt.subplots(ncols=ncols, nrows=nrows, figsize=(7,7))
|
230 |
+
s_state = librosa.display.specshow(simulate[0::tf].T, cmap='coolwarm', ax=ax[0,0])
|
231 |
+
a_state = librosa.display.specshow(analytic[0::tf].T, cmap='coolwarm', ax=ax[1,0])
|
232 |
+
e_state = librosa.display.specshow(estimate[0::tf].T, cmap='coolwarm', ax=ax[2,0])
|
233 |
+
|
234 |
+
a_diffs = librosa.display.specshow(diff_ana[0::tf].T, cmap='coolwarm', ax=ax[1,1])
|
235 |
+
e_diffs = librosa.display.specshow(diff_est[0::tf].T, cmap='coolwarm', ax=ax[2,1])
|
236 |
+
|
237 |
+
ax[0,1].plot(simulate[:nt,nx], c='goldenrod', label='FDTD')
|
238 |
+
ax[0,1].plot(analytic[:nt,nx], c='r', label='Modal')
|
239 |
+
ax[0,1].plot(estimate[:nt,nx], c='g', label='Ours')
|
240 |
+
|
241 |
+
a_state.set_clim([-maxval, +maxval])
|
242 |
+
e_state.set_clim([-maxval, +maxval])
|
243 |
+
s_state.set_clim([-maxval, +maxval])
|
244 |
+
a_diffs.set_clim([-maxerr, +maxerr])
|
245 |
+
e_diffs.set_clim([-maxerr, +maxerr])
|
246 |
+
titles = ['FDTD', 'Modal', 'Ours']
|
247 |
+
for i, title in enumerate(titles):
|
248 |
+
ax[i,0].set_ylabel(title)
|
249 |
+
for i in range(nrows):
|
250 |
+
for j in range(ncols):
|
251 |
+
ax[i,j].set_xticks([])
|
252 |
+
ax[i,j].set_yticks([])
|
253 |
+
|
254 |
+
ax[0,1].legend(
|
255 |
+
loc='lower center', bbox_to_anchor=(.95,-0.5),
|
256 |
+
ncol=1, fancybox=True,
|
257 |
+
handlelength=1., handletextpad=0.1, columnspacing=.5, fontsize=7,
|
258 |
+
)
|
259 |
+
|
260 |
+
fig.tight_layout()
|
261 |
+
fig.subplots_adjust(wspace=0)
|
262 |
+
fig.subplots_adjust(hspace=0)
|
263 |
+
|
264 |
+
plt.savefig(save_path, bbox_inches='tight')
|
265 |
+
plt.close('all')
|
266 |
+
plt.clf()
|
267 |
+
|
268 |
+
def state_video(save_dir, state_u, sr, framerate=100, trim_front=True, verbose=False, prefix=None, fname='output', maxy=None):
|
269 |
+
if isinstance(state_u, list):
|
270 |
+
state_v = state_u[1]
|
271 |
+
state_u = state_u[0]
|
272 |
+
else:
|
273 |
+
state_v = None
|
274 |
+
|
275 |
+
if trim_front:
|
276 |
+
state_u = state_u[:int(sr / 55)] # for 55 Hz (A1)
|
277 |
+
state_v = state_v[:int(sr / 55)] if state_v is not None else None
|
278 |
+
downs = int(state_u.shape[0]/framerate)
|
279 |
+
else:
|
280 |
+
downs = 100
|
281 |
+
|
282 |
+
Nt, Nx = state_u.shape
|
283 |
+
maxy = np.max(np.abs(state_u)) if maxy is None else maxy
|
284 |
+
locs = np.linspace(0, 1, Nx)
|
285 |
+
for j in range(Nt // downs):
|
286 |
+
|
287 |
+
plt.figure(figsize=(5,2))
|
288 |
+
if state_v is not None:
|
289 |
+
plt.plot(locs, state_v[j * downs], c='k', alpha=0.5)
|
290 |
+
plt.plot(locs, state_u[j * downs], c='k')
|
291 |
+
plt.xlim([0, 1])
|
292 |
+
plt.ylim([-maxy, maxy])
|
293 |
+
plt.xticks([])
|
294 |
+
plt.yticks([])
|
295 |
+
|
296 |
+
plt.tight_layout()
|
297 |
+
os.makedirs(f'{save_dir}/temp', exist_ok=True)
|
298 |
+
plt.savefig(f'{save_dir}/temp/file%02d.png' % j)
|
299 |
+
plt.clf()
|
300 |
+
plt.close("all")
|
301 |
+
|
302 |
+
prefix = 'fdtd' if prefix is None else prefix
|
303 |
+
with open(os.devnull, 'w') as devnull:
|
304 |
+
silent_video = ['ffmpeg',
|
305 |
+
'-framerate', f'{framerate}',
|
306 |
+
'-i', f'{save_dir}/temp/file%02d.png',
|
307 |
+
'-r', '30', '-pix_fmt', 'yuv420p', '-y',
|
308 |
+
f'{save_dir}/{prefix}-{fname}-silent_video.mp4']
|
309 |
+
output_video = ['ffmpeg',
|
310 |
+
'-i', f'{save_dir}/{prefix}-{fname}-silent_video.mp4',
|
311 |
+
'-i', f'{save_dir}/{fname}.wav',
|
312 |
+
'-c:v', 'copy', '-map', '0:v', '-map', '1:a',
|
313 |
+
'-shortest', '-y',
|
314 |
+
f'{save_dir}/{prefix}-{fname}.mp4']
|
315 |
+
silent_video += ['-loglevel', 'quiet'] if not verbose else []
|
316 |
+
output_video += ['-loglevel', 'quiet'] if not verbose else []
|
317 |
+
subprocess.call(silent_video, stdout=devnull)
|
318 |
+
subprocess.call(output_video, stdout=devnull)
|
319 |
+
|
320 |
+
shutil.rmtree(f"{save_dir}/temp")
|
321 |
+
|
322 |
+
def rainbowgram(
|
323 |
+
save_path, out, sr, n_fft=2**13, hop_length=None,
|
324 |
+
f0_input=None, f0_estimate=None, modes=None, colorbar=True,
|
325 |
+
):
|
326 |
+
L = 32
|
327 |
+
if out.shape[-1] > 2*n_fft:
|
328 |
+
hop_length = n_fft // L if hop_length is None else hop_length
|
329 |
+
else:
|
330 |
+
n_fft = out.shape[-1] // 2
|
331 |
+
hop_length = n_fft // L
|
332 |
+
t_max = out.shape[-1] / sr
|
333 |
+
|
334 |
+
out, gain = rms_normalize(out)
|
335 |
+
D = librosa.stft(out, n_fft=n_fft, hop_length=hop_length, pad_mode='reflect')
|
336 |
+
mag, phase = librosa.magphase(D)
|
337 |
+
|
338 |
+
freqs = librosa.fft_frequencies(sr=sr, n_fft=n_fft)
|
339 |
+
times = librosa.times_like(D, sr=sr, hop_length=hop_length)
|
340 |
+
|
341 |
+
phase_exp = 2 * np.pi * np.multiply.outer(freqs, times)
|
342 |
+
unwrapped_phase = np.unwrap((np.angle(phase)-phase_exp) / (L/4), axis=1)
|
343 |
+
unwrapped_phase_diff = np.diff(unwrapped_phase, axis=1, prepend=0)
|
344 |
+
|
345 |
+
alpha = librosa.amplitude_to_db(mag, ref=np.max) / 80 + 1
|
346 |
+
|
347 |
+
#width = 2.5; height = 1.9
|
348 |
+
width = 7; height = 7
|
349 |
+
fig, ax = plt.subplots(figsize=(width,height))
|
350 |
+
spec = librosa.display.specshow(
|
351 |
+
unwrapped_phase_diff, cmap='hsv', alpha=alpha,
|
352 |
+
n_fft=n_fft, hop_length=hop_length, sr=sr, ax=ax,
|
353 |
+
y_axis='log', x_axis='time',
|
354 |
+
)
|
355 |
+
ax.set_facecolor('#000')
|
356 |
+
if colorbar:
|
357 |
+
cbar = fig.colorbar(spec, ticks=[-np.pi, -np.pi/2, 0, np.pi/2, np.pi], ax=ax)
|
358 |
+
cbar.ax.set(yticklabels=['$-\pi$', '$-\pi/2$', "$0$", '$\pi/2$', '$\pi$']);
|
359 |
+
|
360 |
+
def add_plot(freqs, label=None, ls=None, lw=2., dashes=(None,None)):
|
361 |
+
x = np.linspace(1/sr, t_max, freqs.shape[-1])
|
362 |
+
freqs = np.interp(times, x, freqs)
|
363 |
+
line, = ax.plot(times - times[0], freqs, label=label, color='white', lw=lw, ls=ls, dashes=dashes)
|
364 |
+
return line
|
365 |
+
|
366 |
+
freq_ticks = [0, 128, 512, 2048, 8192, sr // 2]
|
367 |
+
time_ticks = [0, 1, 2]
|
368 |
+
if f0_input is not None:
|
369 |
+
add_plot(f0_input, "f0_input", dashes=(10,5))
|
370 |
+
freq_ticks += [f0_input[0]]
|
371 |
+
|
372 |
+
if f0_estimate is not None:
|
373 |
+
add_plot(f0_estimate, "f0_estimate", dashes=(2,5))
|
374 |
+
freq_ticks += [] if f0_input is not None else [f0_estimate[0]]
|
375 |
+
|
376 |
+
if modes is not None:
|
377 |
+
for im, m in enumerate(modes):
|
378 |
+
l = add_plot(m, f"mode {im}")
|
379 |
+
l.set_dashes([5,10,1,10])
|
380 |
+
|
381 |
+
#ax.set_xticks(time_ticks)
|
382 |
+
#ax.set_yticks(freq_ticks)
|
383 |
+
ax.set_xticks([])
|
384 |
+
ax.set_yticks([])
|
385 |
+
ax.xaxis.set_visible(False)
|
386 |
+
ax.yaxis.set_visible(False)
|
387 |
+
|
388 |
+
plt.tight_layout()
|
389 |
+
plt.savefig(save_path, bbox_inches='tight', pad_inches=-1e-6)
|
390 |
+
plt.clf()
|
391 |
+
plt.close("all")
|
392 |
+
|
393 |
+
def phase_diagram(
|
394 |
+
save_path, x, s,
|
395 |
+
xmin, xmax,
|
396 |
+
dxmin, dxmax,
|
397 |
+
ddxmin, ddxmax,
|
398 |
+
sr, tau=1, label='$u$'):
|
399 |
+
dx = (x[tau:] - x[:-tau]) / (tau / sr)
|
400 |
+
ddx = (x[2*tau:] - 2*x[tau:-tau] + x[:-2*tau]) / (2*tau / sr)
|
401 |
+
|
402 |
+
if s is not None:
|
403 |
+
if s.shape[0] > x.shape[0]:
|
404 |
+
s = s[:x.shape[0]]
|
405 |
+
dsdt = (s[tau:] - s[:-tau]) / (tau / sr)
|
406 |
+
_dsdt = np.mean(np.abs(dsdt), axis=0)
|
407 |
+
spax = np.arange(len(_dsdt))
|
408 |
+
|
409 |
+
if s is not None:
|
410 |
+
fig, ax = plt.subplots(ncols=2, nrows=2, figsize=(8,3.5), width_ratios=[4, 1])
|
411 |
+
ax[0,0].axhline(y=0, color='gray', ls='-', lw=0.3)
|
412 |
+
ax[0,0].plot(x, 'k-', lw=0.5)
|
413 |
+
ax[0,0].set_xlim([0,len(x)])
|
414 |
+
ax[0,0].set_ylim([xmin,xmax])
|
415 |
+
ax[0,0].set_xticks([])
|
416 |
+
ax[0,0].set_yticks([])
|
417 |
+
#ax[0,0].set_xlabel('$t$')
|
418 |
+
ax[0,0].set_ylabel(label)
|
419 |
+
|
420 |
+
ax[0,1].axhline(y=0, color='gray', ls='-', lw=0.3)
|
421 |
+
ax[0,1].axvline(x=0, color='gray', ls='-', lw=0.3)
|
422 |
+
ax[0,1].plot(dx, x[tau:], 'k-', lw=0.5)
|
423 |
+
ax[0,1].set_xlim([dxmin,dxmax])
|
424 |
+
ax[0,1].set_ylim([xmin,xmax])
|
425 |
+
ax[0,1].set_xticks([])
|
426 |
+
ax[0,1].set_yticks([])
|
427 |
+
#ax[0,1].set_xlabel('$d$'+label+'$/dt$')
|
428 |
+
|
429 |
+
#state = librosa.display.specshow(s.T, cmap='coolwarm', ax=ax[1,0])
|
430 |
+
state = librosa.display.specshow(dsdt.T, cmap='coolwarm', ax=ax[1,0])
|
431 |
+
maxabs = np.max(np.abs(dsdt))
|
432 |
+
state.set_clim([-maxabs, +maxabs])
|
433 |
+
ax[1,0].set_xlim([0,x.shape[0]])
|
434 |
+
ax[1,0].set_xlabel('$t$')
|
435 |
+
ax[1,0].set_ylabel('$x$')
|
436 |
+
|
437 |
+
_dsdt = np.pad( _dsdt, (1,1))
|
438 |
+
_spax = np.pad( spax, (1,1), mode='edge')
|
439 |
+
ax[1,1].fill_between(+ _dsdt, _spax, alpha=0.2, facecolor='k')
|
440 |
+
ax[1,1].fill_between(- _dsdt, _spax, alpha=0.2, facecolor='k')
|
441 |
+
ax[1,1].axvline(x=0, color='k', ls='-', lw=1.0)
|
442 |
+
ax[1,1].set_ylim([spax[0], spax[-1]])
|
443 |
+
ax[1,1].set_xticks([])
|
444 |
+
ax[1,1].set_yticks([])
|
445 |
+
ax[1,1].set_xlabel('$d$'+label+'$/dt$')
|
446 |
+
else:
|
447 |
+
fig, ax = plt.subplots(ncols=2, nrows=1, figsize=(8,2), width_ratios=[4, 1])
|
448 |
+
ax[0].axhline(y=0, color='gray', ls='-', lw=0.3)
|
449 |
+
ax[0].plot(x, 'k-', lw=0.5)
|
450 |
+
ax[0].set_xlim([0,len(x)])
|
451 |
+
ax[0].set_ylim([xmin,xmax])
|
452 |
+
ax[0].set_xticks([])
|
453 |
+
ax[0].set_yticks([])
|
454 |
+
ax[0].set_xlabel('$t$')
|
455 |
+
ax[0].set_ylabel(label)
|
456 |
+
|
457 |
+
ax[1].axhline(y=0, color='gray', ls='-', lw=0.3)
|
458 |
+
ax[1].axvline(x=0, color='gray', ls='-', lw=0.3)
|
459 |
+
ax[1].plot(dx, x[tau:], 'k-', lw=0.5)
|
460 |
+
ax[1].set_xlim([dxmin,dxmax])
|
461 |
+
ax[1].set_ylim([xmin,xmax])
|
462 |
+
ax[1].set_xticks([])
|
463 |
+
ax[1].set_yticks([])
|
464 |
+
ax[1].set_xlabel('$d$'+label+'$/dt$')
|
465 |
+
|
466 |
+
plt.tight_layout()
|
467 |
+
plt.subplots_adjust(wspace=0.)
|
468 |
+
plt.subplots_adjust(hspace=0.)
|
469 |
+
plt.savefig(save_path, bbox_inches='tight', transparent=True)
|
470 |
+
plt.clf()
|
471 |
+
plt.close("all")
|
472 |
+
|
473 |
+
|
474 |
+
#fig, ax = plt.subplots(ncols=2, nrows=1, figsize=(3.5,2))
|
475 |
+
|
476 |
+
#ax[0].axvline(x=0, color='gray', ls='-', lw=0.3)
|
477 |
+
#ax[0].axhline(y=0, color='gray', ls='-', lw=0.3)
|
478 |
+
#ax[0].plot(x[2*tau:], ddx, 'k-', lw=0.5)
|
479 |
+
#ax[0].set_xlim([xmin,xmax])
|
480 |
+
#ax[0].set_ylim([ddxmin,ddxmax])
|
481 |
+
#ax[0].set_xticks([])
|
482 |
+
#ax[0].set_yticks([])
|
483 |
+
#ax[0].set_xlabel(label)
|
484 |
+
#ax[0].set_ylabel('$d^2$'+label+'$/dt^2$')
|
485 |
+
|
486 |
+
#ax[1].axvline(x=0, color='gray', ls='-', lw=0.3)
|
487 |
+
#ax[1].axhline(y=0, color='gray', ls='-', lw=0.3)
|
488 |
+
#ax[1].plot(dx[tau:], ddx, 'k-', lw=0.5)
|
489 |
+
#ax[1].set_xlim([dxmin, dxmax])
|
490 |
+
#ax[1].set_ylim([ddxmin,ddxmax])
|
491 |
+
#ax[1].set_xticks([])
|
492 |
+
#ax[1].set_yticks([])
|
493 |
+
#ax[1].set_xlabel('$d$'+label+'$/dt$')
|
494 |
+
|
495 |
+
#plt.tight_layout()
|
496 |
+
#plt.subplots_adjust(wspace=0.)
|
497 |
+
#plt.subplots_adjust(hspace=0.)
|
498 |
+
#save_dir = save_path.split('/')[:-1]
|
499 |
+
#save_name = save_path.split('/')[-1]
|
500 |
+
#save_path_2 = '/'.join(save_dir+[save_name.replace('phs', 'dphs')])
|
501 |
+
#plt.savefig(save_path_2, bbox_inches='tight', transparent=True)
|
502 |
+
#plt.clf()
|
503 |
+
#plt.close("all")
|
504 |
+
|
505 |
+
|
506 |
+
def xt_grid_embedding(save_path, x, t, embed_dim=32, t_gain=1e-6, x_gain=1e-2):
|
507 |
+
t = t * 1000
|
508 |
+
|
509 |
+
Bs, _, Nx = x.shape
|
510 |
+
Bs, Nt, _ = t.shape
|
511 |
+
t_embd = sinusoidal_embedding(t.unsqueeze(-1), n=embed_dim, gain=t_gain) # (Bs, 1,Nx,1,embed_dim)
|
512 |
+
x_embd = sinusoidal_embedding(x.unsqueeze(-1), n=embed_dim, gain=x_gain) # (Bs,Nt, 1,1,embed_dim)
|
513 |
+
|
514 |
+
t_axis = t.squeeze().detach().cpu().numpy()
|
515 |
+
x_axis = x.squeeze().detach().cpu().numpy()
|
516 |
+
t_embd = t_embd.squeeze().detach().cpu().numpy()
|
517 |
+
x_embd = x_embd.squeeze().detach().cpu().numpy()
|
518 |
+
assert len(list(t_embd.shape)) == 2, t_embd.shape
|
519 |
+
assert len(list(x_embd.shape)) == 2, x_embd.shape
|
520 |
+
e = np.arange(embed_dim)
|
521 |
+
|
522 |
+
fig, ax = plt.subplots(figsize=(13,7), nrows=1, ncols=2)
|
523 |
+
librosa.display.specshow(t_embd, ax=ax[0], x_coords=e, y_coords=t_axis)
|
524 |
+
librosa.display.specshow(x_embd, ax=ax[1], x_coords=e, y_coords=x_axis)
|
525 |
+
ax[0].set_title("t embed")
|
526 |
+
ax[0].set_xlabel("embedding dim")
|
527 |
+
ax[0].set_ylabel("time")
|
528 |
+
ax[0].set_yticks(t_axis[0::10])
|
529 |
+
|
530 |
+
ax[1].set_title("x embed")
|
531 |
+
ax[1].set_xlabel("embedding dim")
|
532 |
+
ax[1].set_ylabel("space")
|
533 |
+
ax[1].set_yticks(x_axis[0::10])
|
534 |
+
ax[1].yaxis.set_label_position("right")
|
535 |
+
ax[1].yaxis.tick_right()
|
536 |
+
|
537 |
+
plt.tight_layout()
|
538 |
+
plt.subplots_adjust(wspace=0.)
|
539 |
+
plt.subplots_adjust(hspace=0.)
|
540 |
+
plt.savefig(save_path)
|
541 |
+
plt.clf()
|
542 |
+
plt.close("all")
|
543 |
+
|
544 |
+
def logedc(save_path, logedc, tmax):
|
545 |
+
time = np.linspace(0, tmax, logedc.shape[0])
|
546 |
+
|
547 |
+
fig, ax = plt.subplots(figsize=(3,3))
|
548 |
+
ax.plot(time, logedc)
|
549 |
+
ax.set_xlabel("Time (s)")
|
550 |
+
ax.set_ylabel("Energy (dB)")
|
551 |
+
|
552 |
+
plt.tight_layout()
|
553 |
+
plt.savefig(save_path)
|
554 |
+
plt.clf()
|
555 |
+
plt.close("all")
|
556 |
+
|
557 |
+
def f0curve(save_path, f0_input, f0_estimate, first_mode, tmax):
|
558 |
+
time = np.linspace(0, tmax, len(f0_estimate))
|
559 |
+
|
560 |
+
fig, ax = plt.subplots(figsize=(3,3))
|
561 |
+
ax.plot(time, f0_input, label='$f_0$')
|
562 |
+
ax.plot(time, f0_estimate, label='$f_0^{(\\tt est)}$')
|
563 |
+
ax.plot(time, first_mode, label='$\hat{f_0}$')
|
564 |
+
ax.set_xlabel("Time (s)")
|
565 |
+
ax.set_ylabel("Frequency (Hz)")
|
566 |
+
ax.set_ylim(0, 200)
|
567 |
+
|
568 |
+
plt.legend()
|
569 |
+
plt.tight_layout()
|
570 |
+
plt.savefig(save_path)
|
571 |
+
plt.clf()
|
572 |
+
plt.close("all")
|
573 |
+
|
574 |
+
def spectrum(save_path, out, f0_input, f0_estimate, modes, sr, n_fft=2**14, ylabel=None):
|
575 |
+
t_max = out.shape[-1] / sr
|
576 |
+
n_fft = min(n_fft, out.shape[-1])
|
577 |
+
cr = int(f0_estimate.shape[-1] / t_max) # crepe framerate
|
578 |
+
simulated = out[-n_fft:]
|
579 |
+
f0_input = f0_input[-1]
|
580 |
+
f0_estimate = f0_estimate[-1]
|
581 |
+
modes = [m[-1] for m in modes]
|
582 |
+
|
583 |
+
simulated_fr = 20 * np.log10(np.abs(np.fft.rfft(simulated, n_fft)))
|
584 |
+
freqs = np.linspace(0, sr/2 / 1000, int(n_fft/2+1))
|
585 |
+
|
586 |
+
n_freqs = 1024
|
587 |
+
|
588 |
+
fig, ax = plt.subplots(figsize=(4,2))
|
589 |
+
|
590 |
+
lw = 0.7
|
591 |
+
ax.plot(freqs[:n_freqs], simulated_fr[:n_freqs], 'k', lw=1.)
|
592 |
+
ax.axvline(x=f0_input / 1000, c='r', ls='-', lw=lw, label='$f_0$')
|
593 |
+
ax.axvline(x=f0_estimate / 1000, c='g', ls='--', lw=lw, label='$f_0^{(\\tt est)}$')
|
594 |
+
for i, m in enumerate(modes):
|
595 |
+
if i == 0:
|
596 |
+
ax.axvline(x=m / 1000, c='b', ls='-.', lw=lw, label='$\hat{f_p}$')
|
597 |
+
else:
|
598 |
+
ax.axvline(x=m / 1000, c='b', ls='-.', lw=lw)
|
599 |
+
|
600 |
+
ax.set_xticks([0, 0.5, 1, 1.5, 2])
|
601 |
+
plt.xlim([0, 2])
|
602 |
+
plt.xlabel('Frequency (kHz)')
|
603 |
+
plt.ylabel(ylabel)
|
604 |
+
|
605 |
+
plt.legend(ncol=3, fancybox=True)
|
606 |
+
plt.tight_layout()
|
607 |
+
plt.savefig(save_path, bbox_inches='tight')
|
608 |
+
plt.clf()
|
609 |
+
plt.close("all")
|
610 |
+
|
611 |
+
|
612 |
+
def spectrum_uz(save_path, uout, zout, f0_input, f0_estimate, modes, sr, n_fft=2**14):
|
613 |
+
t_max = uout.shape[-1] / sr
|
614 |
+
n_fft = min(n_fft, uout.shape[-1])
|
615 |
+
cr = int(f0_estimate.shape[-1] / t_max) # crepe framerate
|
616 |
+
simulated_u = uout[-n_fft:]
|
617 |
+
simulated_z = zout[-n_fft:]
|
618 |
+
f0_input = f0_input[-1]
|
619 |
+
f0_estimate = f0_estimate[-1]
|
620 |
+
modes = [m[-1] for m in modes]
|
621 |
+
|
622 |
+
simulated_fr_u = 20 * np.log10(np.abs(np.fft.rfft(simulated_u, n_fft)))
|
623 |
+
simulated_fr_z = 20 * np.log10(np.abs(np.fft.rfft(simulated_z, n_fft)))
|
624 |
+
freqs = np.linspace(0, sr/2 / 1000, int(n_fft/2+1))
|
625 |
+
|
626 |
+
n_freqs = 1024
|
627 |
+
|
628 |
+
fig, ax = plt.subplots(figsize=(2.5,2), ncols=1, nrows=2)
|
629 |
+
#fig, ax = plt.subplots(figsize=(4,2), ncols=1, nrows=2)
|
630 |
+
|
631 |
+
lw = 1.
|
632 |
+
lw_fr = .5
|
633 |
+
al = .5
|
634 |
+
ax[0].axhline(y=0, c='k', lw=0.5, alpha=al)
|
635 |
+
ax[0].plot(freqs[:n_freqs], simulated_fr_u[:n_freqs], 'k', lw=lw_fr)
|
636 |
+
ax[0].axvline(x=f0_input / 1000, c='r', ls='-', lw=lw, label='$f_0$', alpha=al)
|
637 |
+
ax[0].axvline(x=f0_estimate / 1000, c='g', ls='--', lw=lw, label='$f_0^{(\\tt est)}$', alpha=al)
|
638 |
+
for i, m in enumerate(modes):
|
639 |
+
if i == 0:
|
640 |
+
ax[0].axvline(x=m / 1000, c='b', ls=':', lw=lw, label='$\hat{f_p}$', alpha=al)
|
641 |
+
else:
|
642 |
+
ax[0].axvline(x=m / 1000, c='b', ls=':', lw=lw, alpha=al)
|
643 |
+
ax[0].set_xticks([0, 0.5, 1, 1.5, 2])
|
644 |
+
ax[0].set_xlim([0, 2])
|
645 |
+
ax[0].set_ylabel('$|u|$')
|
646 |
+
ax[0].xaxis.set_label_position('top')
|
647 |
+
ax[0].yaxis.tick_right()
|
648 |
+
ax[0].xaxis.tick_top()
|
649 |
+
|
650 |
+
ax[1].axhline(y=0, c='k', lw=0.3, alpha=al)
|
651 |
+
ax[1].plot(freqs[:n_freqs], simulated_fr_z[:n_freqs], 'k', lw=lw_fr)
|
652 |
+
ax[1].axvline(x=f0_input / 1000, c='r', ls='-', lw=lw, label='$f_0$', alpha=al)
|
653 |
+
ax[1].axvline(x=f0_estimate / 1000, c='g', ls='--', lw=lw, label='$f_0^{(\\tt est)}$', alpha=al)
|
654 |
+
for i, m in enumerate(modes):
|
655 |
+
if i == 0:
|
656 |
+
ax[1].axvline(x=m / 1000, c='b', ls=':', lw=lw, label='$\hat{f_p}$', alpha=al)
|
657 |
+
else:
|
658 |
+
ax[1].axvline(x=m / 1000, c='b', ls=':', lw=lw, alpha=al)
|
659 |
+
ax[1].set_xticks([])
|
660 |
+
ax[1].set_xlim([0, 2])
|
661 |
+
ax[1].set_xlabel('Frequency (kHz)')
|
662 |
+
ax[1].set_ylabel('$|\zeta|$')
|
663 |
+
ax[1].yaxis.tick_right()
|
664 |
+
#ax[1].xaxis.set_label_coords(0.2, -0.05)
|
665 |
+
#plt.legend(loc='lower center', bbox_to_anchor=(0.7,-0.4), ncol=3, fancybox=True, handletextpad=0.1, columnspacing=1.)
|
666 |
+
#ax[1].xaxis.set_label_coords(0.2, -0.1)
|
667 |
+
#plt.legend(loc='lower center', bbox_to_anchor=(0.7,-0.8), ncol=3, fancybox=True, handletextpad=0.1, columnspacing=1.)
|
668 |
+
ax[1].xaxis.set_label_coords(0.3, -0.1)
|
669 |
+
plt.legend(loc='lower center', bbox_to_anchor=(.95,-0.5), ncol=3, fancybox=True, handlelength=1., handletextpad=0.1, columnspacing=.5, fontsize=7)
|
670 |
+
|
671 |
+
plt.tight_layout()
|
672 |
+
plt.subplots_adjust(wspace=0.)
|
673 |
+
plt.subplots_adjust(hspace=0.)
|
674 |
+
plt.savefig(save_path, bbox_inches='tight', transparent=True, pad_inches=-1e-6)
|
675 |
+
plt.clf()
|
676 |
+
plt.close("all")
|
677 |
+
|
678 |
+
|
679 |
+
def scatter_xy(save_path, x, y_dict, xlabel, ylabel, xticks=[], yticks=[]):
|
680 |
+
fig, ax = plt.subplots(figsize=(2.5,2.5))
|
681 |
+
for y_label in y_dict.keys():
|
682 |
+
ax.scatter(x, y_dict[y_label], label=y_label, s=1.)
|
683 |
+
ax.set_xlabel(xlabel)
|
684 |
+
ax.set_ylabel(ylabel)
|
685 |
+
|
686 |
+
ax.set_xticks(xticks)
|
687 |
+
ax.set_yticks(yticks)
|
688 |
+
|
689 |
+
plt.legend()
|
690 |
+
plt.tight_layout()
|
691 |
+
plt.savefig(save_path, bbox_inches='tight', transparent=True)
|
692 |
+
plt.clf()
|
693 |
+
plt.close("all")
|
694 |
+
|
695 |
+
|
696 |
+
def scatter_kappa(save_path, total_summary, ss=.3):
|
697 |
+
f0_diffs, f0_ground, kappa, alpha = total_summary
|
698 |
+
|
699 |
+
def moving_average(x, n):
|
700 |
+
assert n % 2 == 1, n
|
701 |
+
x = np.pad(x, (n//2, n//2), 'symmetric')
|
702 |
+
return np.convolve(x, np.ones(n) / n, 'valid')
|
703 |
+
|
704 |
+
sorted_kf = sorted(zip(kappa, f0_ground))
|
705 |
+
sorted_kappa = [k for k, f in sorted_kf]
|
706 |
+
sorted_f0_ground = [f for k, f in sorted_kf]
|
707 |
+
sorted_kappa = sorted_kappa[0::40] + [sorted_kappa[-1]]
|
708 |
+
sorted_f0_ground = sorted_f0_ground[0::40] + [sorted_f0_ground[-1]]
|
709 |
+
|
710 |
+
diff_max = max(f0_diffs) + 3.
|
711 |
+
xticks = [5,10,15,20]
|
712 |
+
yticks = [0,10,20,30,40,50,60]
|
713 |
+
|
714 |
+
fig, ax = plt.subplots(figsize=(2.5,2), nrows=1, ncols=1)
|
715 |
+
#cm = plt.cm.get_cmap('RdYlBu')
|
716 |
+
cm = plt.cm.get_cmap('plasma')
|
717 |
+
|
718 |
+
ax.plot(sorted_kappa, sorted_f0_ground, 'k-', lw=1.0, alpha=0.5)
|
719 |
+
sc = ax.scatter(kappa, f0_diffs, c=alpha, s=ss,
|
720 |
+
vmin=min(alpha), vmax=max(alpha), cmap=cm)
|
721 |
+
|
722 |
+
cbar = plt.colorbar(sc)
|
723 |
+
cbar.ax.set_title(r'$\alpha$')
|
724 |
+
cbar.ax.set_yticks([1,10,20,25])
|
725 |
+
ax.set_xticks(xticks)
|
726 |
+
ax.set_yticks(yticks)
|
727 |
+
ax.set_ylim([0,60])
|
728 |
+
for xt in xticks: ax.axvline(xt, c='k', ls='-', lw=0.5, alpha=0.3)
|
729 |
+
for yt in yticks: ax.axhline(yt, c='k', ls='-', lw=0.5, alpha=0.3)
|
730 |
+
ax.set_xlabel('$\kappa$')
|
731 |
+
ax.set_ylabel(r'$|f_0^{(\tt est)} - f_0|$ (Hz)')
|
732 |
+
ax.xaxis.tick_top()
|
733 |
+
|
734 |
+
plt.tight_layout()
|
735 |
+
#plt.subplots_adjust(wspace=0.)
|
736 |
+
#plt.subplots_adjust(hspace=0.)
|
737 |
+
plt.savefig(save_path, bbox_inches='tight', transparent=True, pad_inches=-1e-5)
|
738 |
+
plt.clf()
|
739 |
+
plt.close("all")
|
740 |
+
|
741 |
+
|
742 |
+
|
743 |
+
def scatter_pluck(save_path, total_summary, ss=.3, al=0.7):
|
744 |
+
cmap = {
|
745 |
+
'$|f_0^{(\\tt est)} - f_0|$' : 'orchid',
|
746 |
+
'$|f_0^{(\\tt est)} - \hat{f_0}|$' : 'cadetblue',
|
747 |
+
}
|
748 |
+
|
749 |
+
f0_diffs, kappa, alpha, p_x, p_a = total_summary
|
750 |
+
|
751 |
+
diff_max = max([max(item) for k, item in f0_diffs.items()]) + 3.
|
752 |
+
ncols = 3 if alpha is None else 4
|
753 |
+
|
754 |
+
fig, ax = plt.subplots(figsize=(4., 2), nrows=1, ncols=ncols)
|
755 |
+
# kappa
|
756 |
+
for y_label in f0_diffs.keys():
|
757 |
+
ax[0].scatter(kappa, f0_diffs[y_label], c=cmap[y_label], label=y_label, s=ss, alpha=al)
|
758 |
+
ax[0].axvline(x=5.88, c='k', ls='--', lw=0.5)
|
759 |
+
#ax[0].axhline(y=6, c='k', ls='--', lw=0.5)
|
760 |
+
#ax[0].axhline(y=1, c='k', ls='--', lw=0.5)
|
761 |
+
ax[0].set_xlabel('$\kappa$')
|
762 |
+
ax[0].set_ylabel('Detune')
|
763 |
+
#ax[0].set_ylim([0, 10])
|
764 |
+
ax[0].set_ylim([0, diff_max])
|
765 |
+
ax[0].set_xticks([2,5,8])
|
766 |
+
ax[0].set_yticks([])
|
767 |
+
ax[0].xaxis.tick_top()
|
768 |
+
|
769 |
+
# p_x
|
770 |
+
for y_label in f0_diffs.keys():
|
771 |
+
ax[1].scatter(p_x, f0_diffs[y_label], c=cmap[y_label], label=y_label, s=ss, alpha=al)
|
772 |
+
#ax[1].axhline(y=6, c='k', ls='--', lw=0.5)
|
773 |
+
#ax[1].axhline(y=1, c='k', ls='--', lw=0.5)
|
774 |
+
ax[1].set_xlabel('$p_x$')
|
775 |
+
ax[1].set_ylim([0, diff_max])
|
776 |
+
ax[1].set_xticks([-0.5, 0])
|
777 |
+
ax[1].set_yticks([])
|
778 |
+
ax[1].xaxis.tick_top()
|
779 |
+
ax[1].yaxis.tick_right()
|
780 |
+
|
781 |
+
# p_a
|
782 |
+
p_a = [x * 1e3 for x in p_a]
|
783 |
+
for y_label in f0_diffs.keys():
|
784 |
+
ax[2].scatter(p_a, f0_diffs[y_label], c=cmap[y_label], label=y_label, s=ss, alpha=al)
|
785 |
+
#ax[2].axhline(y=6, c='k', ls='--', lw=0.5)
|
786 |
+
#ax[2].axhline(y=1, c='k', ls='--', lw=0.5)
|
787 |
+
ax[2].set_xlabel('$p_a\\times10^{3}$')
|
788 |
+
ax[2].set_ylim([0, diff_max])
|
789 |
+
ax[2].set_xticks([1, 4, 7, 10])
|
790 |
+
ax[2].set_yticks([0,5,10])
|
791 |
+
ax[2].xaxis.tick_top()
|
792 |
+
ax[2].yaxis.tick_right()
|
793 |
+
|
794 |
+
|
795 |
+
# alpha
|
796 |
+
if alpha is not None:
|
797 |
+
for y_label in f0_diffs.keys():
|
798 |
+
ax[3].scatter(alpha, f0_diffs[y_label], c=cmap[y_label], label=y_label, s=ss, alpha=al)
|
799 |
+
ax[3].axhline(y=6, c='k', ls='--', lw=0.5)
|
800 |
+
ax[3].axhline(y=1, c='k', ls='--', lw=0.5)
|
801 |
+
ax[3].set_xlabel('$\\alpha$')
|
802 |
+
ax[3].set_ylim([0, diff_max])
|
803 |
+
#ax[3].set_xticks([1,2,3,4])
|
804 |
+
ax[2].set_yticks([])
|
805 |
+
ax[3].set_yticks([0,5,10])
|
806 |
+
ax[3].xaxis.tick_top()
|
807 |
+
|
808 |
+
plt.tight_layout()
|
809 |
+
plt.legend(loc='lower center', bbox_to_anchor=(-0.5, -1.2), ncol=2, fancybox=True, handletextpad=0.02, columnspacing=.2, markerscale=5., fontsize=7)
|
810 |
+
plt.subplots_adjust(wspace=0.)
|
811 |
+
plt.subplots_adjust(hspace=0.)
|
812 |
+
plt.savefig(save_path, bbox_inches='tight', transparent=True, pad_inches=-1e-5)
|
813 |
+
plt.clf()
|
814 |
+
plt.close("all")
|
815 |
+
|
816 |
+
|
817 |
+
|
818 |
+
def time_experiment(save_path, gpu_summary, cpu_summary):
|
819 |
+
|
820 |
+
n_criteria = len(list(gpu_summary.keys()))
|
821 |
+
fig, ax = plt.subplots(figsize=(5, 1.66), nrows=1, ncols=n_criteria)
|
822 |
+
|
823 |
+
config = {
|
824 |
+
'Batch size' : [4, 16, 64, 256, 1024],
|
825 |
+
'$N_t$' : [0.25, 0.50, 1.00, 2.00, 4.00],
|
826 |
+
'$N_x^{(\\tt t)}+N_x^{(\\tt l)}$' : [20, 40, 80, 160, 320],
|
827 |
+
'$N_x^{(\\tt l)}$' : [1, 2, 3, 4],
|
828 |
+
}
|
829 |
+
xlims = {
|
830 |
+
'Batch size' : [2,1800],
|
831 |
+
'$N_t$' : [6000, 300000],
|
832 |
+
#'$N_x^{(\\tt t)}$' : [15, 160],
|
833 |
+
'$N_x^{(\\tt t)}+N_x^{(\\tt l)}$' : [70, 1900],
|
834 |
+
'$N_x^{(\\tt l)}$' : [15, 160],
|
835 |
+
}
|
836 |
+
|
837 |
+
def f0_to_NtNl(f0, k=1/48000, theta_t=0.5 + 2/(np.pi**2), kappa_rel=0.03):
|
838 |
+
gamma = 2*f0
|
839 |
+
kappa = gamma * kappa_rel
|
840 |
+
IHP = (np.pi * kappa / gamma)**2 # inharmonicity parameter (>0); eq 7.21
|
841 |
+
K = pow(IHP, .5) * (gamma / np.pi) # set parameters
|
842 |
+
h = pow( \
|
843 |
+
(gamma**2 * k**2 + pow(gamma**4 * k**4 + 16 * K**2 * k**2 * (2 * theta_t - 1), .5)) \
|
844 |
+
/ (2 * (2 * theta_t - 1)) \
|
845 |
+
, .5)
|
846 |
+
N_t = int(1/h)
|
847 |
+
alpha = 1
|
848 |
+
h = gamma * alpha * k
|
849 |
+
N_l = int(1/h)
|
850 |
+
return N_t + N_l
|
851 |
+
def alpha_to_Nl(alpha, gamma=600, k=1/48000):
|
852 |
+
h = gamma * alpha * k
|
853 |
+
N_l = int(1/h)
|
854 |
+
return N_l
|
855 |
+
|
856 |
+
for i, criterion in enumerate(config.keys()):
|
857 |
+
if criterion == '$N_t$':
|
858 |
+
config[criterion] = [int(c * 48000) for c in config[criterion]]
|
859 |
+
if criterion == '$N_x^{(\\tt t)}+N_x^{(\\tt l)}$':
|
860 |
+
config[criterion] = [f0_to_NtNl(c) for c in config[criterion]]
|
861 |
+
if criterion == '$N_x^{(\\tt l)}$':
|
862 |
+
config[criterion] = [alpha_to_Nl(c) for c in config[criterion]]
|
863 |
+
|
864 |
+
print(config)
|
865 |
+
|
866 |
+
for i, criterion in enumerate(gpu_summary.keys()):
|
867 |
+
conf_list = config[criterion]
|
868 |
+
gpu_times = gpu_summary[criterion]
|
869 |
+
cpu_times = cpu_summary[criterion]
|
870 |
+
|
871 |
+
if i == 0:
|
872 |
+
# divide by number of batch
|
873 |
+
#gpu_times = [gpu_times[k] / conf_list[k] for k in range(len(gpu_times))]
|
874 |
+
#cpu_times = [cpu_times[k] / conf_list[k] for k in range(len(cpu_times))]
|
875 |
+
pass
|
876 |
+
elif i > 1:
|
877 |
+
conf_list = list(reversed(conf_list))
|
878 |
+
gpu_times = list(reversed(gpu_times))
|
879 |
+
cpu_times = list(reversed(cpu_times))
|
880 |
+
gpu_times = [gpu_times[k] / gpu_times[0] for k in range(len(gpu_times))]
|
881 |
+
cpu_times = [cpu_times[k] / cpu_times[0] for k in range(len(cpu_times))]
|
882 |
+
|
883 |
+
lin_times = [conf_list[k] / conf_list[0] for k in range(len(gpu_times))]
|
884 |
+
|
885 |
+
thicklw = 0.8
|
886 |
+
ax[i].axhline(y=100, c='lightgray', lw=thicklw, ls=':')
|
887 |
+
ax[i].axhline(y=10, c='lightgray', lw=thicklw, ls=':')
|
888 |
+
ax[i].axhline(y=1, c='lightgray', lw=thicklw, ls='-')
|
889 |
+
|
890 |
+
ax[i].plot(conf_list[:len(cpu_times)], cpu_times, 'kD--', lw=.9, label="CPU", mfc='lightgray')
|
891 |
+
ax[i].plot(conf_list[:len(gpu_times)], gpu_times, 'ko-', lw=.9, label="GPU", mfc='white')
|
892 |
+
|
893 |
+
ax[i].fill_between(conf_list[:len(cpu_times)], lin_times, alpha=.2)
|
894 |
+
|
895 |
+
ax[i].set_xlabel(criterion)
|
896 |
+
if i == 0:
|
897 |
+
ax[i].set_ylabel('Relative time')
|
898 |
+
|
899 |
+
ax[i].set_xscale('log')
|
900 |
+
ax[i].set_yscale('log')
|
901 |
+
|
902 |
+
ax[i].set_ylim([0.5, 1e3])
|
903 |
+
|
904 |
+
ax[i].set_xlim(xlims[criterion])
|
905 |
+
ax[i].xaxis.set_label_position('top')
|
906 |
+
ax[i].yaxis.tick_right()
|
907 |
+
if i < len(list(gpu_summary.keys()))-1:
|
908 |
+
ax[i].set_yticks([])
|
909 |
+
else:
|
910 |
+
ax[i].set_yticks([1, 10, 100, 1000])
|
911 |
+
|
912 |
+
plt.tight_layout()
|
913 |
+
#plt.legend(loc='lower center', bbox_to_anchor=(-0.5, -0.75), ncol=2, fancybox=True, handletextpad=0.1, columnspacing=1.)
|
914 |
+
#plt.legend(loc='lower right', ncol=2, fancybox=True)
|
915 |
+
plt.legend(loc='upper right', ncol=2, fancybox=True)
|
916 |
+
plt.subplots_adjust(wspace=0.)
|
917 |
+
plt.subplots_adjust(hspace=0.)
|
918 |
+
plt.savefig(save_path, bbox_inches='tight', transparent=True)
|
919 |
+
plt.clf()
|
920 |
+
plt.close("all")
|
921 |
+
|
922 |
+
|
923 |
+
def est_tar_specs(est, tar, inp, plot_path, wave_path, sr=16000):
|
924 |
+
data = []
|
925 |
+
batch_size = est["wav"].shape[0]
|
926 |
+
for b in range(batch_size):
|
927 |
+
logspecs = []
|
928 |
+
difspecs = []
|
929 |
+
|
930 |
+
nrows = 4; ncols = 2
|
931 |
+
height = 8; widths = 7
|
932 |
+
specfig, ax = plt.subplots(nrows, ncols, figsize=(widths,height))
|
933 |
+
|
934 |
+
diff_0 = tar["logmag"][b] - est["logmag"][b]
|
935 |
+
logspecs.append(
|
936 |
+
librosa.display.specshow(
|
937 |
+
inp["logmag"][b].numpy().T, cmap='magma', ax=ax[0,0]))
|
938 |
+
logspecs.append(
|
939 |
+
librosa.display.specshow(
|
940 |
+
est["logmag"][b].numpy().T, cmap='magma', ax=ax[1,0]))
|
941 |
+
logspecs.append(
|
942 |
+
librosa.display.specshow(
|
943 |
+
tar["logmag"][b].numpy().T, cmap='magma', ax=ax[2,0]))
|
944 |
+
difspecs.append(
|
945 |
+
librosa.display.specshow(
|
946 |
+
diff_0.numpy().T, cmap='bwr', ax=ax[3,0]))
|
947 |
+
|
948 |
+
diff_0 = tar["logmel"][b] - est["logmel"][b]
|
949 |
+
logspecs.append(
|
950 |
+
librosa.display.specshow(
|
951 |
+
inp["logmel"][b].numpy().T, cmap='magma',ax=ax[0,1]))
|
952 |
+
logspecs.append(
|
953 |
+
librosa.display.specshow(
|
954 |
+
est["logmel"][b].numpy().T, cmap='magma',ax=ax[1,1]))
|
955 |
+
logspecs.append(
|
956 |
+
librosa.display.specshow(
|
957 |
+
tar["logmel"][b].numpy().T, cmap='magma',ax=ax[2,1]))
|
958 |
+
difspecs.append(
|
959 |
+
librosa.display.specshow(
|
960 |
+
diff_0.numpy().T, cmap='bwr', ax=ax[3,1]))
|
961 |
+
|
962 |
+
for spec in logspecs:
|
963 |
+
spec.set_clim([-60, 30])
|
964 |
+
for spec in difspecs:
|
965 |
+
spec.set_clim([-20, 20])
|
966 |
+
|
967 |
+
titles = ['Analytic', 'Estimate', 'Original', 'Difference']
|
968 |
+
for i, title in enumerate(titles):
|
969 |
+
ax[i,0].set_ylabel(title)
|
970 |
+
|
971 |
+
specfig.tight_layout()
|
972 |
+
specfig.subplots_adjust(wspace=0)
|
973 |
+
specfig.subplots_adjust(hspace=0)
|
974 |
+
|
975 |
+
specfig.savefig(plot_path)
|
976 |
+
plt.close('all')
|
977 |
+
plt.clf()
|
978 |
+
|
979 |
+
inp_wav = inp["wav"][b].squeeze()
|
980 |
+
sf.write(wave_path.replace('.wav', f"-{b}-inp.wav"), inp_wav, samplerate=sr)
|
981 |
+
|
982 |
+
est_wav = est["wav"][b].squeeze()
|
983 |
+
sf.write(wave_path.replace('.wav', f"-{b}-est.wav"), est_wav, samplerate=sr)
|
984 |
+
tar_wav = tar["wav"][b].squeeze()
|
985 |
+
sf.write(wave_path.replace('.wav', f"-{b}-tar.wav"), tar_wav, samplerate=sr)
|
986 |
+
|
987 |
+
d = [ wandb.Image(specfig) ]
|
988 |
+
d += [ wandb.Audio(inp_wav, sample_rate=sr) ]
|
989 |
+
d += [ wandb.Audio(est_wav, sample_rate=sr) ]
|
990 |
+
d += [ wandb.Audio(tar_wav, sample_rate=sr) ]
|
991 |
+
data.append(d)
|
992 |
+
|
993 |
+
columns = ["spec"]
|
994 |
+
columns += ["analytic", "estimate", "original"]
|
995 |
+
return {
|
996 |
+
"columns": columns,
|
997 |
+
"data": data,
|
998 |
+
}
|
999 |
+
|
1000 |
+
|
1001 |
+
def rde_specs(factors, est, sim, plot_path, wave_path, sr=16000):
|
1002 |
+
data = []
|
1003 |
+
num_factors = len(factors)
|
1004 |
+
# plot_path = f'test/plot/rde.png'
|
1005 |
+
mag_path = plot_path.replace('rde.png', 'rde-mag.png')
|
1006 |
+
mel_path = plot_path.replace('rde.png', 'rde-mel.png')
|
1007 |
+
seu_path = plot_path.replace('rde.png', 'rde-state-pinn-u.png')
|
1008 |
+
sez_path = plot_path.replace('rde.png', 'rde-state-pinn-z.png')
|
1009 |
+
ssu_path = plot_path.replace('rde.png', 'rde-state-fdtd-u.png')
|
1010 |
+
ssz_path = plot_path.replace('rde.png', 'rde-state-fdtd-z.png')
|
1011 |
+
|
1012 |
+
#==============================
|
1013 |
+
# plot logmag
|
1014 |
+
#==============================
|
1015 |
+
specs = []
|
1016 |
+
magfig, ax = plt.subplots(nrows=num_factors, ncols=2, figsize=(5,7))
|
1017 |
+
for i in range(num_factors):
|
1018 |
+
specs.append(librosa.display.specshow(
|
1019 |
+
sim["logmag"][i].numpy().T, cmap='magma',ax=ax[i,0]))
|
1020 |
+
specs.append(librosa.display.specshow(
|
1021 |
+
est["logmag"][i].numpy().T, cmap='magma',ax=ax[i,1]))
|
1022 |
+
for spec in specs: spec.set_clim([-60, 30])
|
1023 |
+
for i, fc in enumerate(factors): ax[i,0].set_ylabel(r"$x\times" + f"{fc}$")
|
1024 |
+
ax[0,0].set_title('FDTD')
|
1025 |
+
ax[0,1].set_title('PINN')
|
1026 |
+
magfig.tight_layout()
|
1027 |
+
magfig.subplots_adjust(wspace=0)
|
1028 |
+
magfig.subplots_adjust(hspace=0)
|
1029 |
+
magfig.savefig(mag_path)
|
1030 |
+
plt.close('all')
|
1031 |
+
plt.clf()
|
1032 |
+
|
1033 |
+
#==============================
|
1034 |
+
# plot logmel
|
1035 |
+
#==============================
|
1036 |
+
specs = []
|
1037 |
+
melfig, ax = plt.subplots(nrows=num_factors, ncols=2, figsize=(5,7))
|
1038 |
+
for i in range(num_factors):
|
1039 |
+
specs.append(librosa.display.specshow(
|
1040 |
+
sim["logmel"][i].numpy().T, cmap='magma',ax=ax[i,0]))
|
1041 |
+
specs.append(librosa.display.specshow(
|
1042 |
+
est["logmel"][i].numpy().T, cmap='magma',ax=ax[i,1]))
|
1043 |
+
for spec in specs: spec.set_clim([-60, 30])
|
1044 |
+
for i, fc in enumerate(factors): ax[i,0].set_ylabel(r"$x\times" + f"{fc}$")
|
1045 |
+
ax[0,0].set_title('FDTD')
|
1046 |
+
ax[0,1].set_title('PINN')
|
1047 |
+
melfig.tight_layout()
|
1048 |
+
melfig.subplots_adjust(wspace=0)
|
1049 |
+
melfig.subplots_adjust(hspace=0)
|
1050 |
+
melfig.savefig(mel_path)
|
1051 |
+
plt.close('all')
|
1052 |
+
plt.clf()
|
1053 |
+
|
1054 |
+
#==============================
|
1055 |
+
# plot state
|
1056 |
+
#==============================
|
1057 |
+
u_states = []; dustates = []
|
1058 |
+
z_states = []; dzstates = []
|
1059 |
+
|
1060 |
+
eu_fig, eu_ax = plt.subplots(num_factors, 2, figsize=(7,7))
|
1061 |
+
ez_fig, ez_ax = plt.subplots(num_factors, 2, figsize=(7,7))
|
1062 |
+
su_fig, su_ax = plt.subplots(num_factors, 2, figsize=(7,7))
|
1063 |
+
sz_fig, sz_ax = plt.subplots(num_factors, 2, figsize=(7,7))
|
1064 |
+
|
1065 |
+
u_max = 0
|
1066 |
+
z_max = 0
|
1067 |
+
|
1068 |
+
cm = 'coolwarm'
|
1069 |
+
for i, fc in enumerate(factors):
|
1070 |
+
e_dif = est["state"][i] - est["state"][-1]
|
1071 |
+
s_dif = sim["state"][i] - sim["state"][-1]
|
1072 |
+
Nt = int(sr * 30 / 1000)
|
1073 |
+
u_states.append(librosa.display.specshow(sim["state"][i][:Nt,:,0].numpy().T, cmap=cm,ax=su_ax[i,0]))
|
1074 |
+
u_states.append(librosa.display.specshow(est["state"][i][:Nt,:,0].numpy().T, cmap=cm,ax=eu_ax[i,0]))
|
1075 |
+
dustates.append(librosa.display.specshow(s_dif[:Nt,:,0].numpy().T, cmap=cm,ax=su_ax[i,1]))
|
1076 |
+
dustates.append(librosa.display.specshow(e_dif[:Nt,:,0].numpy().T, cmap=cm,ax=eu_ax[i,1]))
|
1077 |
+
|
1078 |
+
z_states.append(librosa.display.specshow(sim["state"][i][:Nt,:,1].numpy().T, cmap=cm,ax=sz_ax[i,0]))
|
1079 |
+
z_states.append(librosa.display.specshow(est["state"][i][:Nt,:,1].numpy().T, cmap=cm,ax=ez_ax[i,0]))
|
1080 |
+
dzstates.append(librosa.display.specshow(s_dif[:Nt,:,1].numpy().T, cmap=cm,ax=sz_ax[i,1]))
|
1081 |
+
dzstates.append(librosa.display.specshow(e_dif[:Nt,:,1].numpy().T, cmap=cm,ax=ez_ax[i,1]))
|
1082 |
+
|
1083 |
+
u_max = max(u_max, sim["state"][i][:Nt,:,0].abs().max(), est["state"][i][:Nt,:,0].abs().max())
|
1084 |
+
z_max = max(z_max, sim["state"][i][:Nt,:,1].abs().max(), est["state"][i][:Nt,:,1].abs().max())
|
1085 |
+
su_ax[i,0].set_ylabel(r"$x\times" + f"{fc}$")
|
1086 |
+
eu_ax[i,0].set_ylabel(r"$x\times" + f"{fc}$")
|
1087 |
+
sz_ax[i,0].set_ylabel(r"$x\times" + f"{fc}$")
|
1088 |
+
ez_ax[i,0].set_ylabel(r"$x\times" + f"{fc}$")
|
1089 |
+
|
1090 |
+
for stat in u_states: stat.set_clim([-u_max, u_max])
|
1091 |
+
for stat in z_states: stat.set_clim([-z_max, z_max])
|
1092 |
+
for stat in dustates: stat.set_clim([-u_max/10, u_max/10])
|
1093 |
+
for stat in dzstates: stat.set_clim([-z_max/10, z_max/10])
|
1094 |
+
|
1095 |
+
eu_fig.tight_layout(); eu_fig.subplots_adjust(wspace=0); eu_fig.subplots_adjust(hspace=0)
|
1096 |
+
ez_fig.tight_layout(); ez_fig.subplots_adjust(wspace=0); ez_fig.subplots_adjust(hspace=0)
|
1097 |
+
su_fig.tight_layout(); su_fig.subplots_adjust(wspace=0); su_fig.subplots_adjust(hspace=0)
|
1098 |
+
sz_fig.tight_layout(); sz_fig.subplots_adjust(wspace=0); sz_fig.subplots_adjust(hspace=0)
|
1099 |
+
|
1100 |
+
eu_fig.savefig(seu_path)
|
1101 |
+
ez_fig.savefig(sez_path)
|
1102 |
+
su_fig.savefig(ssu_path)
|
1103 |
+
sz_fig.savefig(ssz_path)
|
1104 |
+
plt.close('all')
|
1105 |
+
plt.clf()
|
1106 |
+
|
1107 |
+
for i, factor in enumerate(factors):
|
1108 |
+
fstr = f"{factor:.1f}".replace('.', '_')
|
1109 |
+
# wave_path = f'test/wave/rde.wav'
|
1110 |
+
we_path = wave_path.replace('rde.wav', f'rde-pinn-{fstr}.wav')
|
1111 |
+
ws_path = wave_path.replace('rde.wav', f'rde-fdtd-{fstr}.wav')
|
1112 |
+
est_wav = est["wav"][i].squeeze()
|
1113 |
+
sim_wav = sim["wav"][i].squeeze()
|
1114 |
+
sf.write(we_path, est_wav, samplerate=sr)
|
1115 |
+
sf.write(ws_path, sim_wav, samplerate=sr)
|
1116 |
+
|
1117 |
+
d = [ wandb.Image(magfig) ]; columns = ["logmag"]
|
1118 |
+
d += [ wandb.Image(melfig) ]; columns += ["logmel"]
|
1119 |
+
d += [ wandb.Image(eu_fig) ]; columns += ["PINN-u"]
|
1120 |
+
d += [ wandb.Image(su_fig) ]; columns += ["FDTD-u"]
|
1121 |
+
d += [ wandb.Image(ez_fig) ]; columns += ["PINN-z"]
|
1122 |
+
d += [ wandb.Image(sz_fig) ]; columns += ["FDTD-z"]
|
1123 |
+
d += [ wandb.Audio(est_wav, sample_rate=sr) ]; columns += ["PINN wav"]
|
1124 |
+
d += [ wandb.Audio(sim_wav, sample_rate=sr) ]; columns += ["FDTD wav"]
|
1125 |
+
data.append(d)
|
1126 |
+
|
1127 |
+
return {
|
1128 |
+
"columns": columns,
|
1129 |
+
"data": data,
|
1130 |
+
}
|
1131 |
+
|
1132 |
+
|