szin94 committed on
Commit
bc3e180
1 Parent(s): 14500e9

first commit

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ ckpt/dmsp.ckpt filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .*.swp
2
+ .*.py.swp
3
+ */.*.py.swp
4
+ */*/.*.py.swp
5
+ */*/*/.*.py.swp
6
+ __pycache__/
7
+ */__pycache__/
8
+ */*/__pycache__/
9
+ */*/*/__pycache__/
10
+
11
+ src/*.py
12
+ src/configs
13
+ src/dataset
14
+ src/task
15
+ src/model/cpp
16
+ src/model/*.py
17
+ check.py
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: Dmsp
3
  emoji: 😻
4
  colorFrom: indigo
5
  colorTo: red
 
1
  ---
2
+ title: dmsp
3
  emoji: 😻
4
  colorFrom: indigo
5
  colorTo: red
app.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import yaml
4
+ import torch
5
+ import __main__
6
+ import numpy as np
7
+ import soundfile as sf
8
+ import librosa
9
+ import librosa.display
10
+ import matplotlib.pyplot as plt
11
+
12
+ import gradio as gr
13
+
14
+ from src.model.nn.synthesizer import Synthesizer
15
+ from src.utils.misc import triangular, downsample
16
+ from src.utils.plot import state_video as plot_state_video
17
+ from src.utils.audio import mel_basis, state_to_wav
18
+ from src.utils.control import vibrato as control_vibrato
19
+
20
class ConfigArgument:
    """Plain attribute bag that can also be indexed like a mapping.

    ``obj[key]`` reads the attribute named ``key`` and ``obj[key] = value``
    stores ``value`` under that attribute name.
    """

    def __getitem__(self, key):
        # Subscript reads are forwarded to ordinary attribute lookup.
        return getattr(self, key)

    def __setitem__(self, key, value):
        # Subscript writes become ordinary attribute assignment.
        setattr(self, key, value)


# Publish the class on the ``__main__`` module under its own name.
# NOTE(review): presumably needed so a pickled checkpoint that references
# ``__main__.ConfigArgument`` can be unpickled -- confirm.
__main__.ConfigArgument = ConfigArgument
26
+
27
def filter_state_dict(ckpt):
    """Return a new dict with a leading ``'model.'`` removed from each key.

    Keys whose string form does not start with ``'model.'`` are copied
    unchanged; values are never modified and the key order is preserved.
    """
    prefix = 'model.'
    return {
        (name[len(prefix):] if str(name).startswith(prefix) else name): ckpt[name]
        for name in ckpt.keys()
    }
33
+
34
def flush(directory):
    """Make sure ``directory`` exists and delete everything ``*`` matches in it.

    Regular files and symlinks are unlinked; real sub-directories are removed
    recursively (the previous version raised on them). As before, entries
    whose name starts with a dot are not matched by ``*`` and are left alone.

    Args:
        directory: path of the directory to create and empty.
    """
    import shutil  # local import: only this function needs it

    os.makedirs(directory, exist_ok=True)
    # Escape the directory part so characters such as '[' in its name are
    # taken literally instead of being interpreted as glob wildcards.
    pattern = os.path.join(glob.escape(directory), '*')
    for entry in glob.glob(pattern):
        if os.path.isdir(entry) and not os.path.islink(entry):
            shutil.rmtree(entry)
        else:
            os.remove(entry)
39
+
40
def add_glissando(f_0, Nt, sr, glissando, max_t):
    """Scale a pitch contour by a randomly placed 0 -> 1 ramp.

    The ramp is zero for a random lead-in (up to 20 % of ``sr * max_t``
    samples), rises linearly, and stays at one for a random tail (30-50 % of
    ``sr * max_t`` samples). The result is ``f_0 * (1 + glissando * ramp)``.

    Args:
        f_0: 1-D tensor holding the pitch contour.
        Nt: unused; kept for call compatibility.
        sr: sampling rate used to turn ``max_t`` into a sample count.
        glissando: relative pitch change reached at the end of the ramp.
        max_t: duration (seconds) the random segment lengths are scaled by.
    """
    # Two draws from NumPy's global RNG: lead-in first, then tail.
    lead_draw = np.random.rand()
    n_lead = int(0.2 * lead_draw * sr * max_t)
    tail_draw = np.random.rand()
    n_tail = int((0.2 * tail_draw + 0.3) * sr * max_t)
    n_rise = max(0, len(f_0) - n_lead - n_tail)

    segments = (
        torch.zeros(n_lead),
        torch.linspace(0, 1, n_rise),
        torch.ones(n_tail),
    )
    ramp = glissando * torch.cat(segments, dim=-1)
    return f_0 * (1 + ramp)
46
+
47
def plot_spectrogram(path, x, n_fft=2048, hop_length=512, n_mel=256, samplerate=48000, max_duration=1):
    """Save a mel-weighted magnitude spectrogram image of ``x`` to ``path``.

    The signal is placed at the start of a zero buffer of
    ``int(max_duration * samplerate)`` samples, so every image covers the same
    time span; samples beyond that length are dropped (previously they caused
    a broadcasting error).

    Args:
        path: output image path handed to ``plt.savefig``.
        x: 1-D array of audio samples.
        n_fft: FFT size, also used as the window length.
        hop_length: hop between STFT frames, in samples.
        n_mel: number of mel bands.
        samplerate: sampling rate of ``x`` in Hz.
        max_duration: length of the plotted time axis in seconds.
    """
    n_samples = int(max_duration * samplerate)
    x_wave = np.zeros(n_samples)
    x = x[:n_samples]  # keep the copy below within the buffer
    x_wave[:len(x)] += x
    x_spec = librosa.stft(
        x_wave, n_fft=n_fft, hop_length=hop_length, win_length=n_fft, pad_mode='reflect')
    mag = np.abs(x_spec)                             # (n_freq, n_frames)
    mel_fbank = mel_basis(samplerate, n_fft, n_mel)  # (n_mel, n_freq)
    mel = np.einsum('ij,jk->ik', mel_fbank, mag)     # (n_mel, n_frames)

    plt.figure(figsize=(7,7))
    try:
        librosa.display.specshow(mel)
        plt.xticks([])
        plt.yticks([])
        plt.clim([0, 30])
        plt.tight_layout()
        plt.savefig(path, transparent=True)
    finally:
        # Close every open figure even if plotting or saving fails. No
        # plt.clf() afterwards: with nothing open it would create a new
        # empty figure that then stays alive until the next call.
        plt.close('all')
65
+
66
# Settings read from the checkpoint directory; ``configs`` is later indexed by
# key (e.g. 'sr', 'block_size', 'gain') and expanded into Synthesizer(**configs).
with open("ckpt/config.yaml") as config_file:
    configs = yaml.safe_load(config_file)

# Lookup table indexed with the note name chosen in the UI.
with open("ckpt/pitch.yaml") as pitch_file:
    pitch_dict = yaml.safe_load(pitch_file)
71
+
72
def get_data(duration, resolution, note, glissando, vibrato, stiffness, tension, pluck, amplitude):
    """Turn the UI control values into the tensors the synthesizer consumes.

    Args:
        duration: length of the simulation in seconds.
        resolution: number of spatial grid points (cast to int).
        note: key into ``pitch_dict`` giving the base pitch.
        glissando: relative pitch change passed to ``add_glissando``.
        vibrato: modulation amount passed to ``control_vibrato`` as ``ma``.
        stiffness: value placed in the ``ka`` tensor.
        tension: value placed in the ``al`` tensor.
        pluck: pluck position passed to ``triangular``.
        amplitude: pluck amplitude passed to ``triangular``.

    Returns:
        ``(params, f_0, u_0)`` where ``params`` is
        ``[xg, tg, ka, al, T60, None, None]``; the two ``None`` slots are the
        optional mode frequencies / coefficients unpacked by the synthesizer.
    """
    sr = configs['sr']
    Nt = int(duration * sr)
    Nx = int(resolution)

    xgrid = torch.linspace(0,1,Nx)
    tgrid = torch.arange(Nt) / sr
    pitch = pitch_dict[note]

    # Fixed decay specification, laid out as [[freq_1, T60_1], [freq_2, T60_2]].
    T60 = torch.Tensor([[[1000., 25.],[100., 30.]]])        # (1, 2, 2)

    xg, tg = torch.meshgrid(xgrid, tgrid, indexing='ij')    # (Nx, Nt) each
    ka = torch.Tensor([stiffness]).view(-1,1)               # (1, 1)
    al = torch.Tensor([tension]).view(-1,1)                 # (1, 1)
    f_0 = torch.ones(Nt) * pitch                            # (Nt,)
    nx = torch.Tensor([[[Nx]]]).float()
    p_x = torch.ones_like(nx) * pluck
    p_a = torch.ones_like(nx) * amplitude
    u_0 = triangular(Nx, nx, p_x, p_a)                      # presumably (1, 1, Nx) -- TODO confirm

    f_0 = add_glissando(f_0, Nt, sr, glissando, Nt / sr)
    # NOTE(review): src.utils.control.vibrato already returns f0 + modulation,
    # so this sum contains the base pitch twice. Left unchanged because the
    # shipped demo may rely on it -- confirm the intended pitch scaling.
    f_0 = f_0 + control_vibrato(f_0.view(1,-1), 1/sr, mf=[3.,5.], ma=vibrato)
    f_0 = downsample(f_0, factor=configs['block_size'])

    # One batch row per spatial grid point.
    xg = xg[:,0].view(-1,1)     # (Nx, 1)
    ka = ka.repeat(Nx,1)        # (Nx, 1)
    al = al.repeat(Nx,1)        # (Nx, 1)
    f_0 = f_0.repeat(Nx,1)      # first dimension tiled Nx times
    u_0 = u_0.repeat(Nx,1,1)    # first dimension tiled Nx times

    params = [xg, tg, ka, al, T60, None, None]
    return params, f_0, u_0
110
+
111
_MODEL_CACHE = {}


def _load_model(ckpt_path='ckpt/dmsp.ckpt'):
    """Build the Synthesizer and load its weights once; reuse it afterwards."""
    if ckpt_path not in _MODEL_CACHE:
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        checkpoint = torch.load(ckpt_path, map_location='cpu')
        state_dict = filter_state_dict(checkpoint['state_dict'])
        model = Synthesizer(**configs)
        model.load_state_dict(state_dict)
        _MODEL_CACHE[ckpt_path] = model
    return _MODEL_CACHE[ckpt_path]


def run(duration, resolution, pitch, glissando, vibrato, stiffness, tension, pluck, amplitude):
    """Gradio callback: synthesize one sound and write its result files.

    The arguments are the UI control values, forwarded unchanged to
    ``get_data``. The ``results`` directory is emptied, then a WAV file and a
    spectrogram image are written into it and ``plot_state_video`` is called.

    Returns:
        ``(spec_name, video_name)``: path of the spectrogram image and the
        path the video is expected at (presumably written by
        ``plot_state_video`` from ``prefix`` and ``fname`` -- confirm).
    """
    # The model used to be rebuilt and the checkpoint re-read on every call.
    model = _load_model()

    params, f_0, u_0 = get_data(
        duration, resolution, pitch, glissando, vibrato, stiffness, tension, pluck, amplitude)

    with torch.no_grad():
        ut, mode_input, mode_output = model(params, f_0, u_0)
    ut = ut.detach()
    # Averaging over the first axis gives a single waveform.
    ut_wave = configs['gain'] * ut.mean(0)

    save_dir = 'results'
    prefix = 'dmsp'
    fname = 'output'
    flush(save_dir)
    audio_name = f'{save_dir}/{fname}.wav'
    video_name = f'{save_dir}/{prefix}-{fname}.mp4'
    spec_name = f'{save_dir}/spec.png'

    ut = ut.numpy().T
    ut_wave = ut_wave.numpy()
    maxy = 0.022
    sf.write(audio_name, ut_wave, samplerate=configs['sr'])
    plot_spectrogram(spec_name, ut_wave, samplerate=configs['sr'])
    plot_state_video(save_dir, ut, configs['sr'], prefix=prefix, fname=fname, maxy=maxy)
    return spec_name, video_name
140
+
141
# Note names offered in the "Pitch" dropdown.
pitch_list = [
    "G2", "Ab2", "A2", "Bb2", "B2",
    "C3", "Db3", "D3", "Eb3", "E3", "F3", "Gb3", "G3", "Ab3", "A3", "Bb3", "B3",
    "C4", "Db4", "D4", "Eb4", "E4", "F4", "Gb4", "G4",
]

# Input widgets; they are passed to ``run`` in the order listed in ``inputs``.
duration = gr.Slider(
    0.1, 1.0, value=1.0,
    label="Time Duration",
)
resolution = gr.Slider(
    128, 256, value=256,
    label="Space Resolution",
    info='Reduce to simulate faster. Recommended to leave it as 256.',
)
pitch = gr.Dropdown(
    pitch_list, value="C3",
    label="Pitch",
    info="Specify the fundamental frequency as a musical note.",
)
glissando = gr.Slider(
    -0.4, 0.4, value=0,
    label="Glissando",
    info='Set +/- to ascend (+) or descend (-) the pitch',
)
vibrato = gr.Slider(
    0, 0.25, value=0,
    label="Vibrato",
    info='Set larger value to add more vibrato',
)
stiffness = gr.Slider(
    0.011, 0.029, value=0.02,
    label="Stiffness",
    info='Stiffness can change the resulting pitch. Specify low values when tension is high',
)
tension = gr.Slider(
    1.0, 25, value=4,
    label="Tension",
    info='Tension can introduce non-linear effects such as pitch glide. Specify low values when stiffness is high',
)
pluck = gr.Slider(
    0.12, 0.5, value=0.2,
    label="Pluck Position",
    info='Peak position of an initial condition',
)
amplitude = gr.Slider(
    0.001, 0.02, value=0.015,
    label="Pluck Amplitude",
    info='Peak amplitude of an initial condition',
)

# ``run`` returns (image path, video path), matching the two outputs below.
demo = gr.Interface(
    fn=run,
    inputs=[
        duration, resolution, pitch, glissando, vibrato,
        stiffness, tension, pluck, amplitude,
    ],
    outputs=[
        gr.Image(),
        gr.Video(format='mp4', include_audio=True),
    ],
)
demo.launch()
165
+
166
+
ckpt/config.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ n_modes: 40
2
+ n_bands: 65
3
+ embed_dim: 128
4
+ use_precomputed_mod: False
5
+ harmonic: 'inharmonic'
6
+ hidden_dim: 512
7
+ block_size: 256
8
+ sr: 48000
9
+ gain: 100
10
+ x_scale: [0., 1.]
11
+ t_scale: [0., .3]
12
+ gamma_scale: [196, 880]
13
+ kappa_scale: [.01, .03]
14
+ alpha_scale: [1., 30.]
15
+ sig_0_scale: [0., 0.7]
16
+ sig_1_scale: [0., 0.00001]
ckpt/dmsp.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:071f2f176f1ced82e287ee5d41ffbe36dc6f613e7e5b1dea4e19e9609f82655b
3
+ size 104471835
ckpt/pitch.yaml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ G2 : 98.00
2
+ Ab2: 103.83
3
+ A2 : 110.00
4
+ Bb2: 116.54
5
+ B2 : 123.47
6
+ C3 : 130.81
7
+ Db3: 138.59
8
+ D3 : 146.83
9
+ Eb3: 155.56
10
+ E3 : 164.81
11
+ F3 : 174.61
12
+ Gb3: 185.00
13
+ G3 : 196.00
14
+ Ab3: 207.65
15
+ A3 : 220.00
16
+ Bb3: 233.08
17
+ B3 : 246.94
18
+ C4 : 261.63
19
+ Db4: 277.18
20
+ D4 : 293.66
21
+ Eb4: 311.13
22
+ E4 : 329.63
23
+ F4 : 349.23
24
+ Gb4: 369.99
25
+ G4 : 392.00
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch==1.12.1
2
+ numpy==1.24.3
3
+ einops
4
+ librosa
5
+ omegaconf
src/model/nn/blocks.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ from einops import rearrange
7
+
8
+ from src.utils import misc as utils
9
+
10
def apply_gain(x, gain, fn=None):
    """Scale consecutive slices of ``x`` (last axis) by one gain each.

    ``x`` is chunked into ``len(gain)`` pieces along its last axis; piece
    ``i`` is multiplied by ``gain[i]`` and the pieces are concatenated back.
    When ``fn`` is given it is applied to ``gain`` first.
    """
    if fn is not None:
        gain = fn(gain)
    pieces = x.chunk(len(gain), -1)
    scaled = [g * piece for g, piece in zip(gain, pieces)]
    return torch.cat(scaled, dim=-1)
15
+
16
class FMBlock(nn.Module):
    """Adds a learned, gated offset (scaled by ``omega``) to its input."""

    def __init__(self, input_dim, embed_dim, num_features):
        super().__init__()
        feature_dim = embed_dim * num_features
        concat_size = feature_dim + embed_dim
        # RFF2 is built with embed_dim//2 while concat_size reserves embed_dim
        # for it -- presumably fourier_feature doubles the width; confirm.
        self.rff2 = RFF2(input_dim, embed_dim//2)
        self.tmlp = mlp(concat_size, feature_dim, 5)
        self.proj = nn.Linear(concat_size, 2*input_dim)
        self.activation = nn.GLU(dim=-1)

        # Learnable gains: one per feature group (input) and one scalar (gate).
        self.gain_in = nn.Parameter(torch.randn(num_features) / 2, requires_grad=True)
        self.gain_out = nn.Parameter(torch.Tensor([0.1]), requires_grad=True)

    def forward(self, input, feature, slider, omega):
        ''' input  : (B T input_dim)
            feature: (B T feature_dim)
            slider : (B T 1)
            omega  : multiplied into the offset; shape not visible here
                     (the caller passes a (b, t, 1) tensor)
        '''
        embedded = self.rff2(input / (1.3*math.pi) - 1)
        weighted = apply_gain(feature, self.gain_in, torch.tanh)

        hidden = self.tmlp(torch.cat((embedded, weighted), dim=-1))
        offset = self.activation(self.proj(torch.cat((hidden, embedded), dim=-1)))

        # tanh(0) == 0, so the offset vanishes when slider == 1.
        gate = torch.tanh((slider - 1) * self.gain_out)
        return input + omega * offset * gate
46
+
47
class AMBlock(nn.Module):
    """Rescales its input by a learned factor ``1 + x``."""

    def __init__(self, input_dim, embed_dim, num_features):
        super().__init__()
        feature_dim = embed_dim * num_features
        concat_size = feature_dim + embed_dim
        self.rff2 = RFF2(input_dim, embed_dim//2)
        self.tmlp = mlp(concat_size, feature_dim, 5)
        self.proj = nn.Linear(concat_size, 2*input_dim)
        self.activation = nn.GLU(dim=-1)

        # One learnable gain per feature group, squashed with tanh in forward.
        self.gain_in = nn.Parameter(torch.randn(num_features) / 2, requires_grad=True)

    def forward(self, input, feature, slider):
        ''' input  : (B T input_dim)
            feature: (B T feature_dim)
            slider : (B T 1)   (accepted but not used)
        '''
        embedded = self.rff2(input * 110 - 0.55)
        weighted = apply_gain(feature, self.gain_in, torch.tanh)

        hidden = self.tmlp(torch.cat((embedded, weighted), dim=-1))
        factor = self.activation(self.proj(torch.cat((hidden, embedded), dim=-1)))

        return input * (1 + factor)
74
+
75
class ModBlock(nn.Module):
    """Per-element modulation: each input entry is scaled by ``1 + x``.

    ``embed_dim`` and the ``slider`` argument of ``forward`` are accepted but
    not used.
    """

    def __init__(self, input_dim, feature_dim, embed_dim):
        super().__init__()
        width = 1 + feature_dim          # one input scalar + the feature vector
        self.tmlp = mlp(width, feature_dim, 2)
        self.proj = nn.Linear(width, 2)  # GLU halves this to a single value
        self.activation = nn.GLU(dim=-1)

    def forward(self, input, feature, slider):
        ''' input  : (B T input_dim)
            feature: (B T feature_dim)
            slider : (B T 1)
        '''
        scalars = input.unsqueeze(-1)                        # (B T input_dim 1)
        n_entries = scalars.size(-2)
        # The same feature vector is paired with every input entry.
        tiled = feature.unsqueeze(-2).expand(-1, -1, n_entries, -1)
        hidden = self.tmlp(torch.cat((scalars, tiled), dim=-1))
        factor = self.activation(self.proj(torch.cat((hidden, scalars), dim=-1)))
        return (scalars * (1 + factor)).squeeze(-1)
94
+
95
def mlp(in_size, hidden_size, n_layers):
    """Return an ``nn.Sequential`` of ``n_layers`` (Linear, PReLU) pairs.

    The first Linear maps ``in_size -> hidden_size``; every later one maps
    ``hidden_size -> hidden_size``.
    """
    layers = []
    width = in_size
    for _ in range(n_layers):
        layers.append(nn.Linear(width, hidden_size))
        layers.append(nn.PReLU())
        width = hidden_size
    return nn.Sequential(*layers)
103
+
104
+
105
class RFF2(nn.Module):
    """Fourier-feature embedding with a constant basis and a learned scale.

    Unlike ``RFF`` the basis is not random: every entry of ``N`` equals
    ``1 / (input_dim * embed_dim)``. Only the scalar ``e`` is trainable.
    """

    def __init__(self, input_dim, embed_dim, scale=1.):
        super().__init__()
        basis = torch.ones((input_dim, embed_dim)) / input_dim / embed_dim
        # Kept as a buffer: part of the state dict, never optimized.
        self.register_buffer('N', nn.Parameter(basis, requires_grad=False))
        self.register_parameter(
            'e', nn.Parameter(torch.Tensor([scale]), requires_grad=True))

    def forward(self, x):
        ''' x: (Bs, Nt, input_dim)
            -> (Bs, Nt, embed_dim)   as stated by the original author;
               the real width depends on utils.fourier_feature
        '''
        return utils.fourier_feature(x, self.e * self.N)
124
+
125
+
126
class RFF(nn.Module):
    """ Random Fourier Features Module

    Each input channel gets its own random basis row (buffer ``N``) and its
    own learnable exponent (parameter ``e``); the row is scaled by
    ``10 ** e`` before being used as the projection.
    """

    def __init__(self, scales, embed_dim):
        super().__init__()
        n_inputs = len(scales)
        basis = nn.Parameter(torch.randn(n_inputs, embed_dim), requires_grad=False)
        exponents = nn.Parameter(torch.Tensor(scales).view(-1,1), requires_grad=True)
        self.register_buffer('N', basis)
        self.register_parameter('e', exponents)

    def forward(self, x):
        ''' x: (Bs, Nt, input_dim)
            -> (Bs, Nt, input_dim*embed_dim)   as stated by the original
               author; the real width depends on utils.fourier_feature
        '''
        n_inputs = self.N.size(0)
        columns = x.chunk(n_inputs, -1)        # (Bs, Nt, 1) * input_dim
        rows = self.N.chunk(n_inputs, 0)       # (1, embed_dim) * input_dim
        embedded = []
        for idx, row in enumerate(rows):
            projection = torch.pow(10, self.e[idx]) * row
            embedded.append(utils.fourier_feature(columns[idx], projection))
        return torch.cat(embedded, dim=-1)
147
+
148
+
149
class ModeEstimator(nn.Module):
    """Predicts per-mode amplitudes and frequencies from the initial condition.

    With ``inharmonic=True`` the frequencies come from a second MLP; otherwise
    they are integer multiples of a base angular frequency.
    """

    def __init__(self, n_modes, hidden_dim, kappa_scale=None, gamma_scale=None, inharmonic=True, sr=48000):
        super().__init__()
        self.sr = sr
        self.kappa_scale = kappa_scale
        self.gamma_scale = gamma_scale
        self.rff = RFF([1.]*5, hidden_dim//2)
        self.a_mlp = mlp(5*hidden_dim, hidden_dim, 2)
        self.a_proj = nn.Linear(hidden_dim, n_modes)
        self.tanh = nn.Tanh()
        # The frequency head exists only in the inharmonic configuration.
        self.f_mlp = mlp(5*hidden_dim, hidden_dim, 2) if inharmonic else None
        self.f_proj = nn.Linear(hidden_dim, n_modes) if inharmonic else None
        self.sigmoid = nn.Sigmoid()

    def forward(self, u_0, x_p, kappa, gamma):
        ''' u_0   : (b, 1, x)
            x_p   : (b, 1, 1)
            kappa : (b, 1, 1)
            gamma : (b, 1, 1)
            -> (mode_amps, mode_freq)
        '''
        # Peak index and peak height of the initial condition. 255 and 0.02
        # are fixed normalizers -- presumably the largest grid index and the
        # largest pluck amplitude; confirm.
        p_x = torch.argmax(u_0, dim=-1, keepdim=True) / 255.       # (b, 1, 1)
        p_a = torch.max(u_0, dim=-1, keepdim=True).values / 0.02   # (b, 1, 1)
        kappa = self.normalize_kappa(kappa)
        gamma = self.normalize_gamma(gamma)
        con = torch.cat((p_x, p_a, x_p, kappa, gamma), dim=-1)     # (b, 1, 5)
        con = self.rff(con)   # width must equal 5*hidden_dim to feed a_mlp

        mode_amps = self.a_mlp(con)
        mode_amps = self.tanh(1e-3 * self.a_proj(mode_amps))       # (b, 1, m)

        if self.f_mlp is None:
            # NOTE(review): ``gamma`` was min-max normalized above, so omega
            # is computed from the normalized value -- confirm this is meant.
            int_mults = torch.ones_like(mode_amps).cumsum(-1)      # 1, 2, 3, ...
            omega = gamma / self.sr * (2*math.pi)
            mode_freq = omega * int_mults
        else:
            mode_freq = self.f_mlp(con)
            # Positive increments below 0.3; cumsum makes them increasing.
            mode_freq = 0.3 * self.sigmoid(self.f_proj(mode_freq)) # (b, 1, m)
            mode_freq = mode_freq.cumsum(-1)

        return mode_amps, mode_freq

    @staticmethod
    def _min_max(x, scale):
        # Min-max rescaling; a missing scale leaves x untouched.
        if scale is None:
            return x
        minval = min(scale)
        denval = max(scale) - minval
        return (x - minval) / denval

    def normalize_gamma(self, x):
        """Rescale ``x`` with ``gamma_scale`` (identity when it is None)."""
        return self._min_max(x, self.gamma_scale)

    def normalize_kappa(self, x):
        """Rescale ``x`` with ``kappa_scale`` (identity when it is None)."""
        return self._min_max(x, self.kappa_scale)
208
+
src/model/nn/ddsp.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from src.model.nn.blocks import FMBlock, AMBlock
4
+ from src.utils.ddsp import upsample
5
+ from src.utils.ddsp import remove_above_nyquist_mode
6
+ from src.utils.ddsp import amp_to_impulse_response, fft_convolve
7
+ from src.utils.ddsp import modal_synth
8
+ from src.utils.ddsp import resample
9
+ import math
10
+
11
class DDSP(nn.Module):
    """Modal synthesis plus filtered noise, with optional frequency modulation.

    NOTE(review): FMBlock / AMBlock in src.model.nn.blocks take
    ``(input_dim, embed_dim, num_features)`` and their ``forward`` methods
    take more arguments than are passed here, so this class looks out of sync
    with them -- confirm before using this path.
    """

    def __init__(self,
            feature_size, hidden_size,
            n_modes, n_bands, sampling_rate, block_size,
            fm=False,
        ):
        super().__init__()
        self.n_modes = n_modes

        self.freq_modulator = FMBlock(n_modes, feature_size) if fm else None
        self.coef_modulator = AMBlock(n_modes, feature_size)
        self.noise_proj = nn.Linear(feature_size, n_bands)

        self.register_parameter(
            "noise_gate", nn.Parameter(torch.tensor([1e-2]), requires_grad=True))
        # Stored as buffers so they travel with the module's state.
        self.register_buffer("sampling_rate", torch.tensor(sampling_rate))
        self.register_buffer("block_size", torch.tensor(block_size))

    def forward(self, hidden, mode_freq, mode_coef, times, alpha, lengths):
        ''' hidden    : (Bs,  1, hidden_size)
            mode_freq : (Bs, Nt, n_modes)
            mode_coef : (Bs,  1, n_modes)
            times     : (Bs, Nt, 1)
            alpha     : drives the noise gate
            lengths   : number of output samples to keep
            -> (signal, freq_m, coef_m)
        '''
        if self.freq_modulator is not None:
            freq_m = self.freq_modulator(mode_freq, hidden)
        else:
            freq_m = mode_freq  # integer multiples
        coef_m = self.coef_modulator(mode_coef, hidden, times)

        # ---- harmonic part -------------------------------------------------
        freq_hz = freq_m / (2*math.pi) * self.sampling_rate
        coef_m = remove_above_nyquist_mode(coef_m, freq_hz, self.sampling_rate) # (Bs, Nt, n_modes)
        freq_up = upsample(freq_m, self.block_size).narrow(1,0,lengths)
        coef_up = upsample(coef_m, self.block_size).narrow(1,0,lengths)
        harmonic = modal_synth(freq_up, coef_up, self.sampling_rate)

        # ---- noise part ----------------------------------------------------
        # tanh(0) == 0, so the noise branch is silent when alpha == 1.
        gate = torch.tanh((alpha - 1) * self.noise_gate)
        band_amps = gate * torch.sigmoid(self.noise_proj(hidden) - 5)

        impulse = amp_to_impulse_response(band_amps, self.block_size)
        # Uniform noise in [-1, 1), one block per frame.
        noise = torch.rand(
            impulse.shape[0],
            impulse.shape[1],
            self.block_size,
        ).to(impulse) * 2 - 1
        noise = fft_convolve(noise, impulse).contiguous()
        noise = noise.reshape(noise.shape[0], -1, 1).narrow(1,0,lengths)

        mixed = harmonic + noise
        return mixed.squeeze(-1), freq_m, coef_m
67
+
68
+
69
+
src/model/nn/dmsp.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from src.model.nn.blocks import FMBlock, AMBlock, ModBlock
4
+ from src.utils.ddsp import scale_function, remove_above_nyquist, upsample
5
+ from src.utils.ddsp import remove_above_nyquist_mode
6
+ from src.utils.ddsp import harmonic_synth, amp_to_impulse_response, fft_convolve
7
+ from src.utils.ddsp import modal_synth
8
+ from src.utils.ddsp import resample
9
+ import math
10
+
11
class DMSP(nn.Module):
    """Modal synthesis with learned frequency/amplitude modulation plus noise.

    ``hidden_size`` is accepted but not used by this class.
    """

    def __init__(self,
            embed_dim, hidden_size, n_features,
            n_modes, n_bands, sampling_rate, block_size,
        ):
        super().__init__()
        self.n_modes = n_modes

        self.freq_modulator = FMBlock(n_modes, embed_dim, n_features)
        self.coef_modulator = AMBlock(n_modes, embed_dim, n_features)
        self.proj_noise = nn.Linear(n_features*embed_dim, n_bands)

        # Stored as buffers so they travel with the module's state.
        self.register_buffer("sampling_rate", torch.tensor(sampling_rate))
        self.register_buffer("block_size", torch.tensor(block_size))

    def forward(self, hidden, mode_freq, mode_coef, times, alpha, omega, lengths):
        ''' hidden    : (Bs,  1, hidden_size)
            mode_freq : (Bs, Nt, n_modes)
            mode_coef : (Bs,  1, n_modes)
            times     : (Bs, Nt, 1)
            alpha     : gate input of the frequency modulator
            omega     : scale of the frequency offset
            lengths   : number of output samples to keep
            -> (signal, freq_m, coef_m)
        '''
        freq_m = self.freq_modulator(mode_freq, hidden, alpha, omega)
        coef_m = self.coef_modulator(mode_coef, hidden, times)

        # ---- harmonic part -------------------------------------------------
        freq_hz = freq_m / (2*math.pi) * self.sampling_rate
        coef_m = remove_above_nyquist_mode(coef_m, freq_hz, self.sampling_rate) # (Bs, Nt, n_modes)
        freq_up = upsample(freq_m, self.block_size).narrow(1,0,lengths)
        coef_up = upsample(coef_m, self.block_size).narrow(1,0,lengths)
        harmonic = modal_synth(freq_up, coef_up, self.sampling_rate)

        # ---- noise part ----------------------------------------------------
        band_amps = scale_function(self.proj_noise(hidden) - 5)

        impulse = amp_to_impulse_response(band_amps, self.block_size)
        # Uniform noise in [-1, 1), one block per frame.
        noise = torch.rand(
            impulse.shape[0],
            impulse.shape[1],
            self.block_size,
        ).to(impulse) * 2 - 1
        noise = fft_convolve(noise, impulse).contiguous()
        noise = noise.reshape(noise.shape[0], -1, 1).narrow(1,0,lengths)

        mixed = harmonic + noise
        return mixed.squeeze(-1), freq_m, coef_m
60
+
61
+
62
+
63
+
src/model/nn/synthesizer.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from einops import rearrange
6
+
7
+ from src.utils import audio as audio
8
+
9
class Synthesizer(nn.Module):
    """ Synthesizer Network

    Estimates initial mode amplitudes/frequencies, encodes the physical
    parameters as Fourier features and renders audio with a DMSP
    (``harmonic == 'inharmonic'``) or DDSP network.
    """

    def __init__(self,
            embed_dim=64,
            x_scale=1, t_scale=1,
            gamma_scale=0, kappa_scale=0, alpha_scale=0, sig_0_scale=0, sig_1_scale=0,
            **kwargs):
        super().__init__()
        self.sr = kwargs['sr']
        hidden_dim = kwargs['hidden_dim']
        self.n_modes = kwargs['n_modes']
        inharmonic = kwargs['harmonic'].lower() == 'inharmonic'

        # (min, max)-style ranges used by normalize_params.
        self.x_scale = x_scale
        self.t_scale = t_scale
        self.gamma_scale = gamma_scale
        self.kappa_scale = kappa_scale
        self.alpha_scale = alpha_scale
        self.sig_0_scale = sig_0_scale
        self.sig_1_scale = sig_1_scale

        from src.model.nn.blocks import RFF, ModeEstimator
        # space, time, kappa, alpha, sigma_0, sigma_1, gamma
        n_feats = 7
        self.material_encoder = RFF([1.]*n_feats, embed_dim // 2)
        feature_size = embed_dim * n_feats
        self.mode_estimator = ModeEstimator(
            self.n_modes, embed_dim, kappa_scale, gamma_scale,
            inharmonic=inharmonic,
        )
        if inharmonic:
            from src.model.nn.dmsp import DMSP
            self.net = DMSP(
                embed_dim=embed_dim,
                hidden_size=hidden_dim,
                n_features=n_feats,
                n_modes=kwargs['n_modes'],
                n_bands=kwargs['n_bands'],
                block_size=kwargs['block_size'],
                sampling_rate=kwargs['sr'],
            )
        else:
            from src.model.nn.ddsp import DDSP
            self.net = DDSP(
                feature_size=feature_size,
                hidden_size=hidden_dim,
                n_modes=kwargs['n_modes'],
                n_bands=kwargs['n_bands'],
                block_size=kwargs['block_size'],
                sampling_rate=kwargs['sr'],
                fm=kwargs['ddsp_frequency_modulation'],
            )

    def forward(self, params, pitch, initial):
        ''' params : [space, times, kappa, alpha, t60, mode_freq, mode_coef];
                     the last two may be None to use the estimated modes
            pitch  : fundamental frequency in Hz
            initial: initial condition
            -> ut, [in_freq, in_coef], [ut_freq, ut_coef]
        '''
        space, times, kappa, alpha, t60, mode_freq, mode_coef = params

        f_0   = pitch.unsqueeze(2)              # (b, frames, 1)
        times = times.unsqueeze(-1)             # (b, sample, 1)
        kappa = kappa.unsqueeze(-1)             # (b, 1, 1)
        alpha = alpha.unsqueeze(-1)             # (b, 1, 1)
        space = space.unsqueeze(-1)             # (b, 1, 1)
        gamma = 2*f_0                           # (b, frames, 1)
        omega = f_0 / self.sr * (2*math.pi)     # (b, t, 1)
        relf0 = omega - omega.narrow(1,0,1)     # (b, t, 1) change since frame 0

        # The estimator is conditioned on gamma at frame index 9, so at least
        # ten pitch frames are required.
        # NOTE(review): reason for picking frame 9 is not visible here.
        in_coef, in_freq = self.mode_estimator(initial, space, kappa, gamma.narrow(1,9,1))
        if mode_coef is None:
            mode_coef = in_coef
        if mode_freq is None:
            mode_freq = in_freq
        mode_freq = mode_freq + relf0  # linear FM

        Nt = times.size(1)      # total number of samples
        Nf = mode_freq.size(1)  # total number of frames
        frames = self.get_frame_time(times, Nf)

        n_pitch_frames = f_0.size(1)
        space = space.repeat(1,n_pitch_frames,1)            # (b, frames, 1)
        alpha = alpha.repeat(1,n_pitch_frames,1)            # (b, frames, 1)
        kappa = kappa.repeat(1,n_pitch_frames,1)            # (b, frames, 1)
        sigma = audio.T60_to_sigma(t60, f_0, 2*f_0*kappa)   # (b, frames, 2)

        # fourier features
        feat = self.normalize_params([space, frames, kappa, alpha, sigma, gamma])
        feat = self.material_encoder(feat)  # (b, frames, n_feats * embed_dim)

        # Exponential decay of the mode amplitudes driven by sigma_0.
        damping = torch.exp(- frames * sigma.narrow(-1,0,1))
        mode_coef = mode_coef * damping
        ut, ut_freq, ut_coef = self.net(feat, mode_freq, mode_coef, frames, alpha, omega, Nt)
        return ut, [in_freq, in_coef], [ut_freq, ut_coef]

    def get_frame_time(self, times, Nf):
        """Return ``Nf`` time stamps ``t_0 + k / sr`` for ``k = 1..Nf``.

        ``t_0`` is the first entry of ``times`` along dim 1.
        """
        t_0 = times.narrow(1,0,1)                               # (Bs, 1, 1)
        steps = torch.ones_like(t_0).repeat(1,Nf,1).cumsum(1)   # 1, 2, ..., Nf
        return steps / self.sr + t_0                            # (Bs, Nf, 1)

    def normalize_params(self, params):
        """Min-max rescale each feature with its configured range and concatenate.

        ``params`` is ``[space, times, kappa, alpha, sigma, gamma]``, where
        ``sigma`` holds two channels that are rescaled separately.
        """
        def rescale(var, scale):
            minval = min(scale)
            denval = max(scale) - minval
            return (var - minval) / denval

        space, times, kappa, alpha, sigma, gamma = params
        sig_0, sig_1 = sigma.chunk(2, -1)
        # Time is shifted by the upper bound of t_scale before rescaling.
        normalized = [
            rescale(space, self.x_scale),
            rescale(times - max(self.t_scale), self.t_scale),
            rescale(kappa, self.kappa_scale),
            rescale(alpha, self.alpha_scale),
            rescale(sig_0, self.sig_0_scale),
            rescale(sig_1, self.sig_1_scale),
            rescale(gamma, self.gamma_scale),
        ]
        return torch.cat(normalized, dim=-1)
123
+
124
+
125
+
src/utils/audio.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import librosa
6
+ import soundfile as sf
7
+ from einops import rearrange
8
+
9
+ eps = np.finfo(np.float32).eps
10
+
11
def calculate_rms(amp):
    """Root-mean-square of ``amp`` along its last axis.

    Tensors keep the reduced axis (size 1); NumPy arrays drop it and have
    ``eps`` added under the square root.

    Raises:
        TypeError: if ``amp`` is neither a ``torch.Tensor`` nor ``np.ndarray``.
    """
    if isinstance(amp, torch.Tensor):
        mean_square = amp.pow(2).mean(-1, keepdim=True)
        return mean_square.pow(.5)
    if isinstance(amp, np.ndarray):
        mean_square = np.mean(np.square(amp), axis=-1)
        return np.sqrt(mean_square + eps)
    raise TypeError(f"argument 'amp' must be torch.Tensor or np.ndarray. got: {type(amp)}")
18
+
19
def dB2amp(dB):
    """Convert decibels to a linear amplitude ratio: ``10 ** (dB / 20)``."""
    exponent = dB / 20.
    return np.power(10., exponent)
21
+
22
def amp2dB(amp):
    """Convert a linear amplitude ratio to decibels: ``20 * log10(amp)``."""
    log_ratio = np.log10(amp)
    return 20. * log_ratio
24
+
25
def rms_normalize(wav, ref_dBFS=-23.0, skip_nan=True):
    """Scale ``wav`` so its RMS level matches ``ref_dBFS``.

    Args:
        wav: waveform(s); the RMS is taken over the last axis.
        ref_dBFS: target level, a float or a ``torch.Tensor``.
        skip_nan: if True a waveform containing NaN is returned unchanged
            with gain ``1.``; if False an assertion fails instead.

    Returns:
        ``(normalized_wav, gain)``.
    """
    has_nan = np.isnan(np.sum(wav))
    if not skip_nan:
        assert not has_nan, np.isnan(wav)
    if has_nan:
        return wav, 1.

    # value_dBFS = 20*log10(rms(signal) * sqrt(2)) = 20*log10(rms(signal)) + 3.0103
    level_dB = ref_dBFS - 3.0103
    if isinstance(ref_dBFS, torch.Tensor):
        target_rms = torch.pow(10, level_dB/20.)
    else:
        target_rms = np.power(10, level_dB/20.)
    gain = target_rms / (calculate_rms(wav) + eps)
    return gain * wav, gain
41
+
42
def ell_infty_normalize(wav, skip_nan=True):
    """Peak-normalize ``wav`` so that ``max(|wav|)`` over the last axis is 1.

    All-zero signals are left unchanged (gain 1).

    Args:
        wav: ``np.ndarray`` or ``torch.Tensor``; the last axis is time.
        skip_nan: if True an input containing NaN is returned unchanged with
            gain ``1.``; if False an assertion fails instead.

    Returns:
        ``(normalized_wav, gain)``. For a 1-D NumPy input ``gain`` is a
        scalar; for batched NumPy input and for tensors it keeps a trailing
        axis of size 1 so it broadcasts against ``wav``.

    Raises:
        TypeError: if ``wav`` is neither an ``np.ndarray`` nor a ``torch.Tensor``.
    """
    if isinstance(wav, np.ndarray):
        exists_nan = np.isnan(np.sum(wav))
        if not skip_nan:
            assert not exists_nan, np.isnan(wav)
        if exists_nan:
            return wav, 1.
        if wav.ndim <= 1:
            maxv = np.max(np.abs(wav), axis=-1)
            gain = 1 if maxv == 0 else 1. / maxv
        else:
            # keepdims keeps a per-row gain of shape (..., 1); the previous
            # (B,) gain was broadcast against the time axis instead of rows.
            maxv = np.max(np.abs(wav), axis=-1, keepdims=True)
            # Substitute 1 for zero peaks *before* dividing so silent rows
            # get gain 1 without a divide-by-zero warning.
            gain = 1. / np.where(maxv == 0, 1., maxv)
    elif isinstance(wav, torch.Tensor):
        exists_nan = torch.isnan(wav.sum())
        if not skip_nan:
            assert not exists_nan, torch.isnan(wav)
        if exists_nan:
            return wav, 1.
        maxv = wav.abs().max(-1).values.unsqueeze(-1)
        # 1 if maxv == 0 else 1. / maxv
        gain = torch.where(maxv.eq(0),
                torch.ones_like(maxv), 1. / maxv)
    else:
        raise TypeError(
            f"argument 'wav' must be torch.Tensor or np.ndarray. got: {type(wav)}")
    wav = gain * wav
    return wav, gain
71
+
72
def dB_RMS(wav):
    """Return ``20 * log10(rms(wav))`` for a tensor or NumPy array.

    Any other input type yields ``None``.
    """
    if isinstance(wav, torch.Tensor):
        log10 = torch.log10
    elif isinstance(wav, np.ndarray):
        log10 = np.log10
    else:
        return None
    return 20 * log10(calculate_rms(wav))
77
+
78
def mel_basis(sr, n_fft, n_mel):
    """Mel filter bank from ``librosa.filters.mel`` covering 0 .. sr//2 Hz (``norm=1``)."""
    return librosa.filters.mel(
        sr=sr,
        n_fft=n_fft,
        n_mels=n_mel,
        fmin=0,
        fmax=sr//2,
        norm=1,
    )
80
+
81
def inv_mel_basis(sr, n_fft, n_mel):
    """Transpose of the un-normalized (``norm=None``) mel filter bank, 0 .. sr//2 Hz."""
    fbank = librosa.filters.mel(
        sr=sr, n_fft=n_fft, n_mels=n_mel, norm=None, fmin=0, fmax=sr//2,
    )
    return fbank.T
85
+
86
def lin_to_mel(linspec, sr, n_fft, n_mel=80):
    """Project a linear-frequency spectrogram onto ``n_mel`` mel bands.

    ``linspec`` is left-multiplied by the ``mel_basis`` filter bank, so its
    frequency axis must be the first axis of the matrix product.
    """
    return mel_basis(sr, n_fft, n_mel) @ linspec
89
+
90
def save_waves(est, save_dir, sr=16000):
    """Write the first channel of every batch item to ``{save_dir}/{index}.wav``.

    Args:
        est: batch of signals indexed as ``est[b, 0]``.
        save_dir: existing directory that receives the files.
        sr: sampling rate written to the WAV header.
    """
    # The batch size comes from ``est`` itself; the previous code read an
    # undefined name (``inp``) and raised NameError on every call.
    for b in range(est.shape[0]):
        est_wav = est[b,0].squeeze()
        wave_path = f"{save_dir}/{b}.wav"
        sf.write(wave_path, est_wav, samplerate=sr)
97
+
98
def get_inverse_window(forward_window, frame_length, frame_step):
    """Divide ``forward_window`` by the overlap-added sum of its square.

    The squared window is zero-padded to a multiple of ``frame_step``, folded
    into rows of ``frame_step`` samples and summed, giving the total squared
    weight at each hop offset.
    """
    n_overlaps = -(-frame_length // frame_step)  # Ceiling division.
    padded_length = n_overlaps * frame_step

    squared = torch.square(forward_window)
    squared = F.pad(squared, (0, padded_length - frame_length))
    per_offset = squared.reshape(n_overlaps, frame_step).sum(0, keepdims=True)
    # Repeat the per-offset totals so they line up with every window sample.
    denom = per_offset.tile(n_overlaps, 1).reshape(padded_length)
    return forward_window / denom[:frame_length]
107
+
108
def state_to_wav(state, normalize=True, sr=48000):
    ''' state: (Bs, Nt, Nx)
        -> (Bs, Nt-1): forward difference over time, scaled by ``sr`` and
           summed over the last axis; peak-normalized when ``normalize``
    '''
    assert len(list(state.shape)) == 3, state.shape
    step = state[:, 1:] - state[:, :-1]     # difference between time steps
    vel = (step * sr).sum(-1)
    if normalize:
        return ell_infty_normalize(vel)[0]
    return vel
114
+
115
def state_to_spec(x, window):
    ''' x: (Bs, Nt, Nx, Ch)
        -> (Bs, n_frames, Nx, Ch*n_freq*2)

        STFT along time for every (batch, position, channel) series, with
        real and imaginary parts stored as the innermost pair.
        n_fft is the window length, hop is n_fft // 4, n_freq = n_fft//2 + 1.
    '''
    Bs, Nt, Nx, Ch = x.shape
    n_ffts = window.size(-1)
    n_freq = n_ffts // 2 + 1
    hop_length = n_ffts // 4
    x = rearrange(x, 'b t x c -> (b x c) t')
    # return_complex must be given for real input on current PyTorch;
    # view_as_real reproduces the former (..., 2) real layout.
    s = torch.stft(x, n_ffts, hop_length=hop_length, window=window, return_complex=True)
    s = torch.view_as_real(s)
    s = rearrange(s, '(b x c) f t k -> b t x (c f k)',
                  b=Bs, x=Nx, c=Ch, f=n_freq, k=2)
    return s
128
+
129
def spec_to_state(x, window, length):
    ''' x: (Bs, n_frames, Nx, Ch*n_freq*2)
        -> (Bs, length, Nx, Ch)

        Inverse of ``state_to_spec``: the last axis is unpacked into
        (channel, frequency, real/imag) and inverted with ``torch.istft``.
    '''
    Bs, Nt, Nx, _ = x.shape
    n_ffts = window.size(-1)
    n_freq = n_ffts // 2 + 1

    x = rearrange(x, 'b t x (c f k) -> (b x c) f t k', f=n_freq, k=2)
    # istft needs a complex tensor on current PyTorch; view_as_complex needs
    # the (real, imag) pair to be contiguous in memory.
    x = torch.view_as_complex(x.contiguous())
    x = torch.istft(x, n_ffts, length=length, window=window)
    x = rearrange(x, '(b x c) t -> b t x c', b=Bs, x=Nx)
    return x
141
+
142
+
143
def to_spec(x, window, reduce_channel=True):
    ''' x: (Bs, Nt)
        -> (Bs, n_frames, Nf*2) if reduce_channel==True
        -> (Bs, n_frames, Nf,2) otherwise

        STFT with n_fft equal to the window length and hop n_fft // 4;
        the trailing 2 holds the real and imaginary parts.
    '''
    Bs, _ = x.shape  # also enforces a 2-D input
    n_ffts = window.size(-1)
    n_freq = n_ffts // 2 + 1
    hop_length = n_ffts // 4
    # return_complex must be given for real input on current PyTorch;
    # view_as_real reproduces the former (..., 2) real layout.
    s = torch.stft(x, n_ffts, hop_length=hop_length, window=window, return_complex=True)
    s = torch.view_as_real(s)   # (Bs, Nf, n_frames, 2)
    s = s.transpose(1,2)        # (Bs, n_frames, Nf, 2)
    if reduce_channel:
        s = rearrange(s, 'b t f k -> b t (f k)',
                      b=Bs, f=n_freq, k=2)
    return s
158
+
159
def from_spec(x, window, length):
    ''' x: (Bs, n_frames, Nf*2)
        -> (Bs, length)

        Inverse of ``to_spec(..., reduce_channel=True)``.
    '''
    Bs, Nt, _ = x.shape
    n_ffts = window.size(-1)
    n_freq = n_ffts // 2 + 1

    x = rearrange(x, 'b t (f k) -> b f t k', f=n_freq, k=2)
    # istft needs a complex tensor on current PyTorch; view_as_complex needs
    # the (real, imag) pair to be contiguous in memory.
    x = torch.view_as_complex(x.contiguous())
    x = torch.istft(x, n_ffts, length=length, window=window)
    return x
170
+
171
def adjust_gain(y, x, minmax, ref_dBFS=-23.0):
    """Rescale ``y`` relative to the level of ``x`` with a random dB offset.

    One offset per row is drawn uniformly from ``[minmax[0], minmax[1])``;
    ``y`` is brought to the RMS level of ``x`` and then divided by the linear
    value of that offset.

    Args:
        y: tensor to rescale (last axis is time).
        x: reference tensor whose RMS level is matched.
        minmax: ``[low, high]`` bounds of the random offset in dB.
        ref_dBFS: intermediate reference level used for both gains.
    """
    low, high = minmax[0], minmax[1]
    ran_gain = (high - low) * torch.rand_like(y.narrow(-1,0,1)) + low
    ref_linear = np.power(10, (ref_dBFS-3.0103)/20.)
    ran_linear = torch.pow(10, (ran_gain-3.0103)/20.)

    x_gain = ref_linear / (calculate_rms(x) + eps)
    y_gain = ref_linear / (calculate_rms(y) + eps)

    # y normalized to the reference level, then mapped back to x's level.
    y_at_x_level = y * y_gain / x_gain
    return y_at_x_level / ran_linear
182
+
183
def degrade(x, rir, noise):
    ''' x    : (Bs, Nt)
        rir  : (Bs, Nt)
        noise: (Bs, Nt)
        -> (Bs, Nt): ``x`` convolved with ``rir`` (via FFT) plus noise, each
           stage level-adjusted with a random gain
    '''
    n_rir = rir.size(-1)
    # Zero-pad both signals by the RIR length before the FFT product.
    x_fft = torch.fft.rfft(F.pad(x, (0,n_rir)))
    w_fft = torch.fft.rfft(F.pad(rir, (0,n_rir)))
    wet_x = torch.fft.irfft(x_fft * w_fft).narrow(-1,0,x.size(-1))

    reverberant = adjust_gain(wet_x, x, [-0, 30]) # ser
    scaled_noise = adjust_gain(noise, reverberant, [10, 30]) # snr
    return reverberant + scaled_noise
197
+
198
def T60_to_sigma(T60, f_0, K):
    ''' T60 : (Bs, 2, 2)  [[T60_freq_1, T60_1], [T60_freq_2, T60_2]]
        f_0 : (Bs, Nt, 1) fundamental frequency
        K   : (Bs, Nt, 1) kappa (K == gamma * kappa_rel)
        -> sig : (Bs, Nt, 2)  [sig0, sig1] along the last axis
    '''
    gamma = f_0 * 2
    freq1, time1 = T60.narrow(1,0,1).chunk(2,-1)
    freq2, time2 = T60.narrow(1,1,1).chunk(2,-1)

    def zeta(freq):
        # -gamma^2 + sqrt(gamma^4 + 4 K^2 (2 pi f)^2)
        return - gamma.pow(2) + (gamma.pow(4) + 4 * K.pow(2) * (2 * math.pi * freq).pow(2)).pow(.5)

    zeta1 = zeta(freq1)
    zeta2 = zeta(freq2)
    scale = 6 * math.log(10)
    spread = zeta1 - zeta2

    sig0 = - zeta2 / time1 + zeta1 / time2
    sig0 = scale * sig0 / spread

    sig1 = 1 / time1 - 1 / time2
    sig1 = scale * sig1 / spread

    return torch.cat((sig0, sig1), dim=-1)
218
+
219
+
src/utils/control.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn.functional as F
4
+
5
def constant(f0, n, dtype=None):
    """Hold each batch value for `n` steps.

    f0 (batch_size,)
    n  (int)
    -> (batch_size, n)
    """
    per_batch = f0.unsqueeze(-1)               # (batch_size, 1)
    held = torch.ones(1, n, dtype=dtype)       # (1, n)
    return per_batch * held
10
+
11
def linear(f1, f2, n):
    """Linearly sweep from `f1` to `f2` over `n` samples (end points included).

    f1 (batch_size,)
    f2 (batch_size,)
    n  (int)
    -> (batch_size, n)
    """
    endpoints = torch.stack((f1, f2), dim=-1)  # (batch_size, 2)
    swept = F.interpolate(
        endpoints.unsqueeze(1), size=n, mode='linear', align_corners=True,
    )
    return swept.squeeze(1)  # (batch_size, n)
19
+
20
def glissando(f1, f2, n, mode='linear'):
    """Glide from `f1` to `f2` over `n` samples; only mode='linear' exists."""
    if mode != 'linear':
        raise NotImplementedError(mode)
    return linear(f1, f2, n)
25
+
26
def vibrato(f0, k, mf=(3, 5), ma=0.05, ma_in_hz=False):
    """Add a randomly parameterized vibrato to a pitch contour.

    f0 (batch_size, n): pitch contour
    k  (float): time step, 1/sr
    mf (sequence): modulation frequency bounds ([min, max]), in Hz
    ma (float): upper bound of the modulation amplitude
    ma_in_hz (bool): ma is given in Hz (else: ma is a weighting factor of f0)
    -> (batch_size, n)

    The modulation rate is drawn uniformly from [mf[0], mf[1]]. The previous
    version computed `mf[1] * rand + mf[0]`, i.e. [min, min + max), which
    contradicted the documented [min, max] contract.
    """
    ff = f0.narrow(-1, 0, 1)  # (B, 1) template for the per-batch random draws

    mf_min, mf_max = mf[0], mf[1]
    mod_frq = (mf_max - mf_min) * torch.rand_like(ff) + mf_min  # (B, 1)
    mod_amp = ma * torch.rand_like(ff)                          # (B, 1)

    nt = f0.size(-1)  # total time
    # Vibrato onset, somewhere in the first half; drawn on f0's device.
    vt = torch.floor((nt // 2) * torch.rand_like(ff))
    t = torch.ones_like(f0).cumsum(-1)
    m = t.gt(vt)  # mask `t` for n <= vt
    vibra = m * mod_amp * (1 - torch.cos(2 * np.pi * mod_frq * (t - vt) * k)) / 2
    if not ma_in_hz:
        vibra = vibra * f0
    # Random sign: the vibrato bends either upwards or downwards per batch item.
    return f0 + vibra * torch.randn_like(ff).sign()
46
+
47
def triangle_with_velocity(vel, n, sr_t, sr_x, max_u=.1):
    """Build a clipped, fifth-power triangular displacement curve.

    vel   (batch_size,) velocity
    n     (int)   number of samples
    sr_t  (int)   sampling rate in time
    sr_x  (int)   sampling rate in space
    max_u (float) maximum displacement
    -> (batch_size, n)
    """
    # m/s to non-dimensional quantity
    step = vel.view(-1, 1) * sr_x / sr_t
    step = step * torch.ones_like(step).repeat(1, n)   # (batch_size, n)
    travelled = step.cumsum(1)
    # Rises to max_u and falls back; negative parts are cut by the relu.
    triangle = torch.relu(max_u - (max_u - travelled).abs() - step)
    return triangle.pow(5).clamp(max=0.01)
59
+
60
+
61
+
src/utils/ddsp.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.fft as fft
4
+ import numpy as np
5
+ import librosa as li
6
+ import math
7
+
8
def safe_log(x):
    """Natural log of `x` shifted by 1e-7 so that zero stays finite."""
    shifted = x + 1e-7
    return torch.log(shifted)
10
+
11
@torch.no_grad()
def mean_std_loudness(dataset):
    """Average the per-item loudness mean and std over a dataset.

    Each dataset item is a 3-tuple whose last element is the loudness tensor.
    Returns (mean of per-item means, mean of per-item stds), both updated
    incrementally; an empty dataset yields (0, 0).
    """
    running_mean = 0
    running_std = 0
    for count, (_, _, loudness) in enumerate(dataset, start=1):
        running_mean += (loudness.mean().item() - running_mean) / count
        running_std += (loudness.std().item() - running_std) / count
    return running_mean, running_std
21
+
22
+
23
def multiscale_fft(signal, scales, overlap):
    """Return the STFT magnitudes of `signal`, one per FFT size in `scales`.

    The hop length is `int(scale * (1 - overlap))`; a Hann window of the FFT
    size is used and the spectra are normalized.
    """
    def _magnitude(n_fft):
        window = torch.hann_window(n_fft).to(signal)
        spectrum = torch.stft(
            signal,
            n_fft=n_fft,
            hop_length=int(n_fft * (1 - overlap)),
            win_length=n_fft,
            window=window,
            center=True,
            normalized=True,
            return_complex=True,
        )
        return spectrum.abs()

    return [_magnitude(scale) for scale in scales]
38
+
39
+
40
def resample(x, factor: int):
    """Upsample `x` by `factor` along the frame axis using a Hann kernel.

    x : (batch, frame, channel) -> (batch, factor * frame, channel)
    """
    batch, frame, channel = x.shape
    # Treat every channel of every batch item as an independent 1-D signal.
    rows = x.permute(0, 2, 1).reshape(batch * channel, 1, frame)

    kernel = torch.hann_window(
        factor * 2,
        dtype=rows.dtype,
        device=rows.device,
    ).reshape(1, 1, -1)

    # Zero-stuff: original samples every `factor` steps, last sample repeated.
    stuffed = torch.zeros(rows.shape[0], rows.shape[1], factor * rows.shape[2]).to(rows)
    stuffed[..., ::factor] = rows
    stuffed[..., -1:] = rows[..., -1:]

    padded = torch.nn.functional.pad(stuffed, [factor, factor])
    smooth = torch.nn.functional.conv1d(padded, kernel)[..., :-1]

    return smooth.reshape(batch, channel, factor * frame).permute(0, 2, 1)
58
+
59
+
60
+
61
def upsample(signal, factor):
    """Linearly interpolate `signal` (batch, frame, channel) to `factor` times the frames."""
    channels_first = signal.permute(0, 2, 1)
    target_len = channels_first.shape[-1] * factor
    stretched = nn.functional.interpolate(channels_first, size=target_len, mode='linear')
    return stretched.permute(0, 2, 1)
65
+
66
+
67
def remove_above_nyquist(amplitudes, pitch, sampling_rate):
    """Attenuate harmonics whose frequency reaches the Nyquist limit.

    amplitudes: (batch, frames, n_harmonics)
    pitch:      (batch, frames, 1)

    Harmonics below `sampling_rate / 2` are scaled by 1 + 1e-4, the others
    by 1e-4 (they are strongly attenuated rather than zeroed).
    """
    n_harm = amplitudes.shape[-1]
    # Cumulative sum of the repeated pitch gives pitch, 2*pitch, 3*pitch, ...
    harmonic_freqs = pitch.repeat(1, 1, n_harm).cumsum(-1)
    nyquist = sampling_rate / 2
    keep = (harmonic_freqs < nyquist).float() + 1e-4
    return amplitudes * keep
75
+
76
+
77
def remove_above_nyquist_mode(amplitudes, frequencies, sampling_rate):
    """Attenuate modes whose frequency reaches the Nyquist limit.

    amplitudes:  (batch, frames, n_harmonics)
    frequencies: (batch, frames, n_harmonics)

    Modes below `sampling_rate / 2` are scaled by 1 + 1e-4, the others by 1e-4.
    """
    nyquist = sampling_rate / 2
    keep = (frequencies < nyquist).float() + 1e-4
    return amplitudes * keep
83
+
84
def scale_function(x):
    """Map `x` into roughly (0, 2): 2 * sigmoid(x) ** ln(10) + 1e-7."""
    exponent = math.log(10)
    squashed = torch.sigmoid(x) ** exponent
    return 2 * squashed + 1e-7
87
+
88
def extract_loudness(signal, sampling_rate, block_size, n_fft=2048):
    """Compute an A-weighted log-magnitude loudness curve, one value per frame.

    signal        : 1-D audio array handed to `librosa.stft`
    sampling_rate : sample rate in Hz (used for the A-weighting frequencies)
    block_size    : STFT hop length in samples
    n_fft         : FFT / window size

    Returns the per-frame mean over frequency bins, with the last frame dropped.
    """
    S = li.stft(
        signal,
        n_fft=n_fft,
        hop_length=block_size,
        win_length=n_fft,
        center=True,
    )
    S = np.log(abs(S) + 1e-7)
    # Keyword arguments: recent librosa releases reject positional sr / n_fft.
    f = li.fft_frequencies(sr=sampling_rate, n_fft=n_fft)
    a_weight = li.A_weighting(f)

    S = S + a_weight.reshape(-1, 1)

    # Average over frequency bins, then drop the trailing frame.
    S = np.mean(S, 0)[..., :-1]

    return S
105
+
106
+
107
def extract_pitch(signal, sampling_rate, block_size):
    """Estimate one f0 value per block of `block_size` samples with CREPE.

    Returns a 1-D array of length `signal.shape[-1] // block_size`; the CREPE
    track is linearly resampled when its length differs from that.
    """
    # `crepe` is used below but was never imported by this module, so calling
    # the function raised NameError. It is imported lazily here so that the
    # module itself still loads when the optional dependency is missing.
    import crepe

    length = signal.shape[-1] // block_size
    f0 = crepe.predict(
        signal,
        sampling_rate,
        step_size=int(1000 * block_size / sampling_rate),  # milliseconds
        verbose=1,
        center=True,
        viterbi=True,
    )
    # Element 1 of the prediction tuple is the frequency track; drop the last frame.
    f0 = f0[1].reshape(-1)[:-1]

    if f0.shape[-1] != length:
        f0 = np.interp(
            np.linspace(0, 1, length, endpoint=False),
            np.linspace(0, 1, f0.shape[-1], endpoint=False),
            f0,
        )

    return f0
127
+
128
+
129
def harmonic_synth(pitch, amplitudes, sampling_rate):
    """Additive synthesis of harmonics of `pitch`.

    pitch      : (batch, time, 1) frequency in Hz
    amplitudes : (batch, time, n_harmonics)
    -> (batch, time, 1)
    """
    n_harmonic = amplitudes.shape[-1]
    # Integrate the angular frequency over time to get the fundamental phase.
    phase = torch.cumsum(2 * math.pi * pitch / sampling_rate, 1)
    harmonic_numbers = torch.arange(1, n_harmonic + 1).to(phase)
    partials = torch.sin(phase * harmonic_numbers) * amplitudes
    return partials.sum(-1, keepdim=True)
135
+
136
def modal_synth(modes, amplitude, sampling_rate, n_chunks=16):
    """Sum of cosines whose phases are the running sum of `modes` over time.

    The time axis (dim 1) is processed in `n_chunks` pieces, carrying the last
    phase of each piece into the next. `sampling_rate` is accepted but unused.
    """
    mode_chunks = modes.chunk(n_chunks, 1)
    amp_chunks = amplitude.chunk(n_chunks, 1)
    carried_phase = torch.zeros_like(mode_chunks[0])
    pieces = []
    for mode_piece, amp_piece in zip(mode_chunks, amp_chunks):
        phase = mode_piece.cumsum(1) + carried_phase
        pieces.append((torch.cos(phase) * amp_piece).sum(-1, keepdim=True))
        # The final phase of this piece is the offset for the next one.
        carried_phase = phase.narrow(1, -1, 1)
    return torch.cat(pieces, 1)
147
+
148
+
149
def amp_to_impulse_response(amp, target_size):
    """Turn a real magnitude response into a windowed impulse response.

    The magnitudes are used as a purely real spectrum, inverted with `irfft`,
    centered, Hann-windowed, zero-padded on the right to `target_size`
    samples and rotated back.
    """
    # Real part = magnitudes, imaginary part = 0.
    spectrum = torch.view_as_complex(torch.stack([amp, torch.zeros_like(amp)], -1))
    response = fft.irfft(spectrum)

    filter_size = response.shape[-1]

    # Center the response so that the window tapers both of its tails.
    centered = torch.roll(response, filter_size // 2, -1)
    window = torch.hann_window(filter_size, dtype=centered.dtype, device=centered.device)
    tapered = centered * window

    padded = nn.functional.pad(tapered, (0, int(target_size) - int(filter_size)))
    # Undo the centering shift.
    return torch.roll(padded, -filter_size // 2, -1)
165
+
166
+
167
def fft_convolve(signal, kernel):
    """Convolve `signal` with `kernel` along the last axis via the FFT.

    The signal is zero-padded on the right and the kernel on the left, each
    to twice its length; the second half of the circular product is returned.
    """
    padded_signal = nn.functional.pad(signal, (0, signal.shape[-1]))
    padded_kernel = nn.functional.pad(kernel, (kernel.shape[-1], 0))

    product = fft.rfft(padded_signal) * fft.rfft(padded_kernel)
    full = fft.irfft(product)

    half = full.shape[-1] // 2
    return full[..., half:]
175
+
src/utils/misc.py ADDED
@@ -0,0 +1,336 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import yaml
3
+ import torch
4
+ import numpy as np
5
+ import torch.nn.functional as F
6
+ from scipy.interpolate import RectBivariateSpline
7
+
8
+ from contextlib import contextmanager,redirect_stderr,redirect_stdout
9
+ from os import devnull
10
+
11
# Alphabet used for random identifiers: digits, lowercase, uppercase.
chars = list('0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ')
12
+
13
@contextmanager
def suppress_stdout_stderr():
    """Context manager that sends both stdout and stderr to the null device.

    Yields the pair (stderr target, stdout target); both are the same
    null-device file object.
    """
    with open(devnull, 'w') as sink, \
            redirect_stderr(sink) as err, \
            redirect_stdout(sink) as out:
        yield (err, out)
19
+
20
def batchify(x, batch_size, n_samples):
    """Placeholder: not implemented yet, always returns None."""
    return None
22
+
23
def random_str(length=8):
    """Return `length` characters drawn (with replacement) from `chars`."""
    picked = np.random.choice(chars, length)
    return "".join(picked)
25
+
26
def sqrt(x):
    """Square root that works for torch tensors and plain numbers alike."""
    if isinstance(x, torch.Tensor):
        return x.pow(.5)
    return x ** .5
28
+
29
def soft_bow(v_rel, a=100):
    """Smooth friction curve sqrt(2a) * v * exp(-a v^2 + 1/2) for relative velocity `v_rel`."""
    gain = np.sqrt(2 * a)
    bell = torch.exp(-a * v_rel**2 + 0.5)
    return gain * v_rel * bell
31
+
32
def hard_bow(v_rel, a=5, eps=0.1, hard_sign=True):
    """Friction curve sign(v) * (eps + (1 - eps) * exp(-a |v|)).

    With `hard_sign=False` the sign is replaced by the smooth `tanh(100 v)`.
    """
    if hard_sign:
        direction = torch.sign(v_rel)
    else:
        direction = torch.tanh(100 * v_rel)
    magnitude = eps + (1 - eps) * torch.exp(-a * v_rel.abs())
    return direction * magnitude
35
+
36
def raised_cosine(N, h, ctr, wid, n):
    """Sample a normalized raised-cosine bump on a grid of `N` points.

    N   (int): number of maximal samples in space
    h   (float): spatial grid cell width (first grid coordinate)
    ctr (B,1,1): center points for each batch
    wid (B,1,1): width lengths for each batch
    n   (B,) or (B,1,1): number of actual samples in space
    -> (B, N, 1), each batch item summing to 1 in absolute value

    `n` is reshaped to (B,1,1) before use. Previously a (B,) tensor, as
    documented, broadcast against the (B,1,1) centers to (B,1,B) and produced
    a (B,N,B) result for B > 1.

    NOTE(review): if no grid point falls inside the bump the normalization
    divides by zero and the result is NaN - unchanged from before.
    """
    xax = torch.linspace(h, 1, N).to(ctr.device).view(1, -1, 1)  # (1, N, 1)
    n = torch.as_tensor(n, device=ctr.device).reshape(-1, 1, 1)   # (B, 1, 1)
    ctr = (ctr * n / N)
    wid = wid / N
    # Indicator of the open interval (ctr - wid/2, ctr + wid/2).
    ind = torch.sign(torch.relu(-(xax - ctr - wid / 2) * (xax - ctr + wid / 2)))
    out = 0.5 * ind * (1 + torch.cos(2 * np.pi * (xax - ctr) / wid))
    return out / out.abs().sum(1, keepdim=True)  # (batch_size, N, 1)
49
+
50
def floor_dirac_delta(n, ctr, N):
    """One-hot (boolean) grid marking the cell index floor(ctr * n).

    n   : number of samples in space
    ctr : center point, as a fraction of the length
    N   (int): number of grid cells
    -> boolean tensor (batch_size, N, 1)
    """
    # Grid indices 0 .. N-1, one row per element of `ctr`.
    cells = torch.ones_like(ctr).view(-1, 1, 1).repeat(1, N, 1).cumsum(1) - 1
    target = torch.floor(ctr * n).view(-1, 1, 1)
    return torch.floor(cells).eq(target)  # (batch_size, N, 1)
59
+
60
def triangular(N, n, p_x, p_a):
    """Triangular displacement profile on a grid of `N` points.

    N   (int): number of maximal samples in space
    n   (B, 1, 1): number of actual samples in space
    p_x (B, Nt, 1): peak position
    p_a (B, Nt, 1): peak amplitude
    -> (B, Nt, N)

    A peak at or outside either end of the string (p_x <= 0 or p_x >= 1)
    yields an all-zero profile. Previously only p_x <= 0 was guarded, for
    both slopes, so p_x == 1 divided by zero in the right slope and produced
    NaN (tripping the assertion below).
    """
    degenerate = p_x.le(0) | p_x.ge(1)
    zeros = torch.zeros_like(p_x)
    vel_l = torch.where(degenerate, zeros, p_a / p_x / n)
    vel_r = torch.where(degenerate, zeros, p_a / (1 - p_x) / n)
    # Left slope: vel_l * index, rising from the left end.
    vel_l = ((vel_l * torch.ones_like(vel_l).repeat(1, 1, N)).cumsum(2) - vel_l).clamp(min=0)
    # Right slope: built as a rising ramp, shifted for n < N, then mirrored.
    vel_r = ((vel_r * torch.ones_like(vel_r).repeat(1, 1, N)).cumsum(2) - vel_r * (N - n + 1)).clamp(min=0).flip(2)
    tri = torch.minimum(vel_l, vel_r)
    assert not torch.isnan(tri).any(), torch.isnan(tri.flatten(1).sum(1))
    return tri
73
+
74
def pre_shaper(x, sr, velocity=10):
    """Fade `x` in along the last axis with a tanh ramp.

    The ramp at sample index i (1-based) is tanh(i / sr * velocity).
    """
    sample_index = torch.ones_like(x).cumsum(-1)
    fade_in = torch.tanh(sample_index / sr * velocity)
    return fade_in * x
77
+
78
def post_shaper(x, sr, pulloff, velocity=100):
    """Fade `x` out with a mirrored tanh ramp and silence its last part.

    Only the first `int(sr * pulloff)` samples are kept; they are weighted by
    the tail of the mirrored ramp, and everything after them is zeroed.
    """
    total = x.size(-1)
    offset = total - int(sr * pulloff)
    ramp = torch.tanh(torch.ones_like(x).cumsum(-1) / sr * velocity).flip(-1)
    # Keep the decaying tail of the ramp, then zero-fill back to full length.
    tail = ramp.narrow(-1, offset, ramp.size(-1) - offset)
    fade_out = F.pad(tail, (0, offset))
    return fade_out * x
83
+
84
def random_uniform(floor, ceiling, size=None, weight=None, dtype=None):
    """Draw uniform random numbers in [floor, ceiling).

    size   : None (scalar result), an int, or a list/tuple of ints
    weight : optional factor applied to the random offset before adding
             `floor`; defaults to ones
    dtype  : dtype the random numbers are cast to

    Previously the default `size=None` was wrapped into `(None,)` and failed,
    and a list size was wrapped into a one-element tuple instead of being
    used as the shape.
    """
    if size is None:
        size = ()
    elif isinstance(size, (list, tuple)):
        size = tuple(size)
    else:
        size = (size,)
    if weight is None:
        weight = torch.ones(size, dtype=dtype)
    # NOTE: torch.rand(..., dtype=dtype) for dtype \in [torch.float32, torch.float64]
    #       can result in different random number generation
    #       (for different precisions; despite fixing the random seed.)
    #       Hence the draw happens in the default precision and is cast afterwards.
    return (ceiling - floor) * torch.rand(size=size).to(dtype) * weight + floor
91
+
92
def equidistant(floor, ceiling, steps, dtype=None):
    """Return `steps` evenly spaced values from `floor` to `ceiling`, cast to `dtype`."""
    grid = torch.linspace(floor, ceiling, steps)
    return grid.to(dtype)
94
+
95
def get_masks(model_name, bs, disjoint=True):
    """Return [bow_mask, hammer_mask], each of shape (bs, 1, 1).

    The masks decide per batch item whether each excitation is imposed.
    Names ending in 'bow', 'hammer' or 'pluck' give constant float masks; any
    other name draws random boolean masks.

    Setting `disjoint=False` allows multiple excitations at once (e.g.,
    bowing over hammered strings). While this could be a charming choice, it
    can also drive the simulation unstable. With `disjoint=True` the hammer
    is switched off wherever the bow is on.
    """
    def _filled(value):
        make = torch.ones if value else torch.zeros
        return make(size=(bs,)).view(-1, 1, 1)

    if model_name.endswith('bow'):
        bow_mask, hammer_mask = _filled(1), _filled(0)
    elif model_name.endswith('hammer'):
        bow_mask, hammer_mask = _filled(0), _filled(1)
    elif model_name.endswith('pluck'):
        bow_mask, hammer_mask = _filled(0), _filled(0)
    else:
        # Bow is drawn first, then hammer (keeps the random stream order).
        bow_mask = torch.rand(size=(bs,)).gt(0.5).view(-1, 1, 1)
        hammer_mask = torch.rand(size=(bs,)).gt(0.5).view(-1, 1, 1)

    if disjoint:
        # Items where both excitations are on lose the hammer.
        both_are_true = torch.logical_and(bow_mask, hammer_mask)
        hammer_mask[both_are_true] = False

    return [bow_mask.view(-1, 1, 1), hammer_mask.view(-1, 1, 1)]
122
+
123
def f0_interpolate(f0_1, n_frames, tmax):
    """Linearly resample the 1-D track `f0_1` onto `n_frames` points over [0, tmax]."""
    source_times = np.linspace(0, tmax, f0_1.shape[0])
    target_times = np.linspace(0, tmax, n_frames)
    return np.interp(target_times, source_times, f0_1)
127
+
128
def interpolate1d(u, xaxis, xvals, k=5):
    """Spline-interpolate a single spatial profile at new positions.

    u:     (1, Nx)
    xaxis: (1, Nx_input)
    xvals: (1, Nx_output)
    -> (1, Nx_output)

    The profile is repeated `k` times along an artificial first axis so that
    a 2-D spline (linear along that axis, degree `k` along x) can be fitted;
    the middle row of the evaluation is returned.
    """
    fake_axis = np.arange(k)[:, None] / k
    stacked = u.repeat(k, 0)
    spline = RectBivariateSpline(fake_axis, xaxis, stacked, kx=1, ky=k)
    evaluated = spline(fake_axis, xvals, grid=True)
    return evaluated[k // 2][None, :]
137
+
138
def interpolate(u, taxis, xaxis, xvals, kx=5, ky=5):
    """Spline-interpolate a space-time field at new spatial positions.

    u:     (Nt, Nx)
    taxis: (Nt, 1)
    xaxis: (1, Nx_input)
    xvals: (1, Nx_output)
    -> (Nt, Nx_output)

    The time axis is evaluated at its original points; `kx` and `ky` are the
    spline degrees along time and space.
    """
    spline = RectBivariateSpline(taxis, xaxis, u, kx=kx, ky=ky)
    resampled = spline(taxis, xvals, grid=True)
    return resampled
147
+
148
def torch_interpolate(x, scale_factor):
    """Rescale `x` along its last axis, then pad back to the original length.

    The length difference is split between both sides; an odd difference puts
    the extra sample on the right.
    """
    y = F.interpolate(x, scale_factor=scale_factor)
    missing = x.size(-1) - y.size(-1)
    left = missing // 2
    right = left if missing % 2 == 0 else left + 1
    return F.pad(y, (left, right))
154
+
155
+
156
def minmax_normalize(x, dim=-1):
    """Rescale `x` along `dim` so that its minimum maps to 0 and its maximum to 1."""
    shifted = x - x.min(dim, keepdim=True).values
    span = shifted.max(dim, keepdim=True).values
    return shifted / span
162
+
163
def get_minmax(x):
    """Return (min, max) of the array `x`, or (None, None) if it contains NaN."""
    contains_nan = np.isnan(x.sum())
    if contains_nan:
        return None, None
    lowest = np.nan_to_num(x.min())
    highest = np.nan_to_num(x.max())
    return lowest, highest
167
+
168
def select_with_batched_index(input, dim, index):
    """Pick, for every batch item, its own element along `dim`.

    input: (bs, ..., n, ...)
    dim  : (int)
    index: (bs, ..., 1, ...) index to select on dim `dim`
    -> out  : (bs, ..., 1, ...) for each batch, select `index`-th element on dim `dim`

    Raises AssertionError when the batch sizes of `input` and `index` differ.
    (The assertion message used to read the misspelled attribute
    `input.shpae`, which turned that failure into an AttributeError.)
    """
    assert input.size(0) == index.size(0), [input.shape, index.shape]
    bs = input.size(0)
    # Handle each batch item separately, then stitch the results together.
    selected = [
        batched_index_select(item, dim, item_index)
        for item, item_index in zip(input.chunk(bs, 0), index.chunk(bs, 0))
    ]
    return torch.cat(selected, dim=0)
182
+
183
def batched_index_select(input, dim, index):
    """Select the same `k` positions along `dim` for every other coordinate.

    input: (..., n, ...)
    dim  : (int)
    index: (..., k, ...) index to select on dim `dim`
    -> out  : (..., k, ...) select k out of n elements on dim `dim`
    """
    rank = len(list(input.shape))
    axis = dim % rank
    # Put all index entries on `axis`, then repeat them over the other axes.
    view_shape = [-1 if position == axis else 1 for position in range(rank)]
    repeats = [1 if position == axis else size for position, size in enumerate(input.shape)]
    gather_index = index.to(torch.int64).view(view_shape).tile(repeats)
    return torch.gather(input, dim, gather_index)
194
+
195
def random_index(max_N, idx_N):
    """Draw `idx_N` random indices from range(max_N).

    Indices are unique when enough are available; otherwise they are drawn
    with replacement.
    """
    if max_N >= idx_N:
        # choosing without replacement
        return torch.randperm(max_N)[:idx_N]
    # choosing with replacement
    return torch.randint(0, max_N, (idx_N,))
202
+
203
def ell_infty_normalize(x, normalize_dims=1):
    """Scale `x` so that the peak magnitude of each leading slice is about 1.

    The first `normalize_dims` axes are kept; the peak is taken over all
    remaining axes. Returns (scaled x, gain), where the gain has singleton
    axes in place of the reduced ones.
    """
    eps = torch.finfo(x.dtype).eps
    full_shape = list(x.shape)
    n_reduced = len(full_shape) - normalize_dims
    gain_shape = full_shape[:normalize_dims] + [1] * n_reduced
    # eps keeps the division finite for all-zero slices.
    peak = x.abs().flatten(normalize_dims).max(normalize_dims).values + eps
    x_gain = 1. / peak.view(gain_shape)
    return x * x_gain, x_gain
210
+
211
def sinusoidal_embedding(x, n, gain=10000, dim=-1):
    """Sinusoidal embedding of `x` with `n` channels (sines first, then cosines).

    Let `x` be normalized to be in the nondimensional (0 ~ 1) range.
    A new axis is appended to `x`; frequencies decay geometrically from 1
    down to 1 / gain and lie along `dim`, which is also the axis the sine and
    cosine halves are concatenated on.
    -> list(x.shape) + [n] for the default `dim`

    `n == 2` used to divide by `half_n - 1 == 0` and return NaN; it now uses
    the single frequency 1. A duplicated statement was removed as well.
    """
    assert n % 2 == 0, n
    x = x.unsqueeze(-1)
    shape = [1] * len(list(x.shape))
    shape[dim] = -1  # e.g., [1,1,-1]
    half_n = n // 2

    expnt = torch.arange(half_n, device=x.device, dtype=x.dtype).view(shape)
    # Guard the single-frequency case, where half_n - 1 is zero.
    denom = max(half_n - 1, 1)
    _embed = torch.exp(expnt * -(np.log(gain) / denom))
    _embed = x * _embed
    emb = torch.cat((torch.sin(_embed), torch.cos(_embed)), dim)
    return emb  # list(x.shape) + [n]
224
+
225
def fourier_feature(x, B):
    """Project `x` with `B` and return the sines and cosines of the result.

    x: (Bs, ..., in_dim)
    B: (in_dim, out_dim), or None to return `x` untouched
    -> (Bs, ..., 2 * out_dim)
    """
    if B is None:
        return x
    projected = (2. * np.pi * x) @ B
    return torch.cat((torch.sin(projected), torch.cos(projected)), dim=-1)
234
+
235
def save_simulation_data(directory, excitation_type, overall_results, constants):
    """Write simulation results and their parameters into `directory`.

    Files written:
      simulation.npz, string_params.npz, hammer_params.npz, bow_params.npz
      simulation_config.yaml  (first element of every parameter, plus
                               `excitation_type`, constants[1] and constants[2])

    overall_results must contain 'string_params' (8 items), 'hammer_params'
    (6 items) and 'bow_params' (6 items); every other entry goes into
    simulation.npz. The caller's dict is left untouched: the three parameter
    entries used to be popped from it in place.
    """
    os.makedirs(directory, exist_ok=True)

    # Shallow copy so that popping does not alter the caller's dict.
    simulation_dict = dict(overall_results)
    string_params = simulation_dict.pop('string_params')
    hammer_params = simulation_dict.pop('hammer_params')
    bow_params = simulation_dict.pop('bow_params')

    string_dict = {
        'kappa': string_params[0],
        'alpha': string_params[1],
        'u0'   : string_params[2],
        'v0'   : string_params[3],
        'f0'   : string_params[4],
        'pos'  : string_params[5],
        'T60'  : string_params[6],
        'target_f0': string_params[7],
    }
    hammer_dict = {
        'x_H'  : hammer_params[0],
        'v_H'  : hammer_params[1],
        'u_H'  : hammer_params[2],
        'w_H'  : hammer_params[3],
        'M_r'  : hammer_params[4],
        'alpha': hammer_params[5],
    }
    bow_dict = {
        'x_B'  : bow_params[0],
        'v_B'  : bow_params[1],
        'F_B'  : bow_params[2],
        'phi_0': bow_params[3],
        'phi_1': bow_params[4],
        'wid_B': bow_params[5],
    }

    def sample(val):
        # First element of an array-like; plain numbers pass through.
        try:
            return val.item(0)
        except AttributeError:
            if isinstance(val, (float, int)):
                return val
            raise

    short_configuration = {
        'excitation_type': excitation_type,
        'theta_t' : constants[1],
        'lambda_c': constants[2],
        'value-string': {key: sample(val) for key, val in string_dict.items()},
        'value-hammer': {key: sample(val) for key, val in hammer_dict.items()},
        'value-bow': {key: sample(val) for key, val in bow_dict.items()},
    }

    np.savez_compressed(f'{directory}/simulation.npz', **simulation_dict)
    np.savez_compressed(f'{directory}/string_params.npz', **string_dict)
    np.savez_compressed(f'{directory}/hammer_params.npz', **hammer_dict)
    np.savez_compressed(f'{directory}/bow_params.npz', **bow_dict)

    with open(f"{directory}/simulation_config.yaml", 'w') as f:
        yaml.dump(short_configuration, f, default_flow_style=False)
299
+
300
def add_noise(x, c, vals, eps=1e-5):
    """Add Gaussian noise of scale `eps` to `x` wherever `c` equals a value in `vals`.

    One noise tensor is drawn per call and reused for every value in `vals`.
    """
    noise = eps * torch.randn_like(x)
    for target in vals:
        # 1 where c matches the target value, 0 elsewhere (in c's dtype).
        hit = torch.where(c == target, torch.ones_like(c), torch.zeros_like(c))
        x = x + hit * noise
    return x
306
+
307
def downsample(x, factor=None, size=None):
    """Linearly resample `x` along time.

    x: (Bs, Nt) -> (Bs, Nt // factor), rounded up when Nt is not a multiple
    of `factor`; alternatively pass the target length as `size` (then
    `factor` must be None).
    """
    if size is not None:
        assert factor is None, [factor, size]
    else:
        length = x.size(1)
        size = length // factor + bool(length % factor)
    resampled = F.interpolate(x.unsqueeze(1), size=size, mode='linear')
    return resampled.squeeze(1)
315
+
316
+
317
if __name__=='__main__':
    # Smoke test: draw one raised-cosine bump and save it as an image.
    #   N (int):     number of maximal samples in space
    #   h (float):   spatial grid cell width
    #   ctr (B,1,1): center points for each batch
    #   wid (B,1,1): width lengths for each batch
    #   n (B,):      number of actual samples in space
    N, B = 10, 1
    h = 1 / N
    ctr = 0.5 * torch.ones(B).view(-1,1,1)
    wid = 1 * torch.ones(B).view(-1,1,1)
    n = N * torch.ones(B)

    c = raised_cosine(N, h, ctr, wid, n)
    print(c.shape)

    import matplotlib.pyplot as plt
    plt.figure()
    plt.plot(c[0,:,0])
    plt.savefig('asdf.png')
336
+
src/utils/plot.py ADDED
@@ -0,0 +1,1132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import subprocess
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import numpy as np
7
+ import librosa
8
+ import matplotlib.pyplot as plt
9
+ import scipy
10
+ from src.utils.control import *
11
+ from src.utils.misc import soft_bow, hard_bow, sinusoidal_embedding
12
+ from src.utils.audio import rms_normalize
13
+ import soundfile as sf
14
+
15
+ plt.rc('text', usetex=True)
16
+ plt.rc('font', family='serif')
17
+
18
def gt_param(TF=5, sr=44100):
    """Build a hand-written bowing gesture of `TF` seconds at sample rate `sr`.

    Returns [x_bow, v_bow, F_bow, f0], four 1-D tensors of int(sr * TF) samples:
    bow position, bow velocity, bow force and fundamental frequency.

    Changes against the previous version:
      * `sr` is honored (it used to be overwritten by a hard-coded 44100);
      * the pitch helpers are called with (1,)-shaped tensors, as `linear`
        and `constant` require (`unsqueeze` is called on their inputs);
      * `vibrato` is called with its current signature (f0, k, mf, ma, ma_in_hz).
    NOTE(review): the old vibrato call passed (207.65, NF//4, k, 5, 10), which
    does not match the current signature; it is read here as a 5 Hz rate with
    up to 10 Hz depth - confirm. TF must be long enough that NF//8 exceeds
    the 50 ms transition.
    """
    NF = int(sr * TF)
    k = 1 / sr
    TRANS = int(0.05 * sr)  # 50 ms of silence between bow strokes
    x_bow = torch.linspace(0.25, 0.45, NF)
    v_bow = 0.1 * torch.tanh(torch.linspace(0., 10, NF))
    F_bow = torch.cat((
        torch.linspace(100, 120, NF//8 - TRANS), torch.zeros(TRANS),
        100 * torch.ones(NF//8 - TRANS),         torch.zeros(TRANS),
        100 * torch.ones(NF//8 - TRANS),         torch.zeros(TRANS),
        torch.linspace(100, 80, NF//8 - TRANS),  torch.zeros(TRANS),
        80 * torch.ones(NF//4),
        torch.zeros(NF//4),
    ), dim=-1)

    def _hz(value):
        # Batch of one: the control helpers expect (batch_size,) tensors.
        return torch.tensor([float(value)])

    f0 = torch.cat((
        glissando(_hz(98), _hz(110), NF//8),
        constant(_hz(130.81), NF//8),
        glissando(_hz(146.83), _hz(164.81), NF//8),
        constant(_hz(207.65), NF//8),
        vibrato(constant(_hz(207.65), NF//4), k, mf=[5, 5], ma=10, ma_in_hz=True),
        constant(_hz(207.65), NF//4),
    ), dim=-1).squeeze(0)  # drop the batch axis -> (n,)

    # Integer division may leave the tracks short of NF; pad at the front.
    F_bow = F.pad(F_bow, (NF-F_bow.size(-1),0))
    f0 = F.pad(f0, (NF-f0.size(-1),0))

    return [x_bow, v_bow, F_bow, f0]
50
+
51
def param(est_param, gt_param, save_path):
    """Plot estimated against ground-truth control curves and save the figure.

    est_param : sequence of tensors; the first four are bow position, bow
                velocity, bow force and f0
    gt_param  : the four matching ground-truth tensors (this argument shadows
                the module-level function of the same name)
    save_path : output image path

    Ground truth is drawn as a blue dotted line, the estimate as a solid
    black line, one row per quantity.
    """
    e_x_bow, e_v_bow, e_F_bow, e_f0 = [item.detach().cpu().numpy() for item in est_param[:4]]
    g_x_bow, g_v_bow, g_F_bow, g_f0 = [item.cpu().numpy() for item in gt_param]

    rows = (
        (g_x_bow, e_x_bow, 'bow pos'),
        (g_v_bow, e_v_bow, 'bow vel'),
        (g_F_bow, e_F_bow, 'bow force'),
        (g_f0, e_f0, 'f0'),
    )

    fig, ax = plt.subplots(figsize=(7,7), nrows=4, ncols=1)
    for axis, (truth, estimate, label) in zip(ax, rows):
        axis.plot(truth, 'b:')
        axis.plot(estimate, 'k-')
        axis.axhline(y=0, c='k', lw=.5)
        axis.set_ylabel(label)

    plt.tight_layout()
    plt.savefig(save_path)
    plt.clf()
    plt.close()
81
+
82
+
83
def simulation_data(
        save_dir,
        uout, zout, v_r_out, F_H_out, u_H_out,
        state_u, state_z,
        string_params, bow_params, hammer_params,
        **kwargs,
    ):
    """Save diagnostic plots of one simulation run into `save_dir`.

    Writes string.png, bow.png, bow-velforce.pdf and hammer.png.
    string_params / bow_params / hammer_params are unpacked into 8 / 6 / 6
    items; extra keyword arguments are ignored.

    Fix: the bow-position marker of the longitudinal-state panel is now drawn
    on that panel (ax[2]); it used to be drawn a second time on ax[1].

    NOTE(review): the friction and hammer plots assume a 48 kHz sample rate
    (hard-coded below) - confirm against the simulation settings.
    """
    N = min(1000, uout.shape[0])

    kappa, alpha, u0, v0, f0, pos, T60, target_f0 = string_params
    x_b, v_b, F_b, phi_0, phi_1, wid_b = bow_params
    x_H, v_H, u_H, w_H, M_r, alpha_H = hammer_params

    max_disp = np.max(np.abs(uout[:N]))
    rels = torch.linspace(-1,1,100)
    prof = hard_bow(rels, phi_0, phi_1)

    # plot string params
    fig, ax = plt.subplots(figsize=(7,7), nrows=5, ncols=1)

    ax[0].plot(f0, 'k-')
    ax[0].axhline(y=0, c='k', lw=.5)
    ax[0].set_ylabel('f0')
    ax[0].yaxis.tick_right()
    ax[0].set_ylim([0, 500])

    # Red line: `pos` from string_params; blue line: final bow position.
    ax[1].plot(np.linspace(0,1,state_u.shape[-1]), state_u[-1], 'k-')
    ax[1].axvline(x=pos, c='r', lw=.5); ax[1].axvline(x=x_b[-1], c='b', lw=.5)
    ax[1].set_ylabel('transverse state')
    ax[1].yaxis.tick_right()

    ax[2].plot(np.linspace(0,1,state_z.shape[-1]), state_z[-1], 'k-')
    ax[2].axvline(x=pos, c='r', lw=.5); ax[2].axvline(x=x_b[-1], c='b', lw=.5)
    ax[2].set_ylabel('longitudinal state')
    ax[2].yaxis.tick_right()

    ax[3].plot(np.arange(N), uout[:N], 'k-')
    ax[3].axhline(y=0, c='k', lw=.5)
    ax[3].set_ylabel('output')
    ax[3].yaxis.tick_right()
    ax[3].set_ylim([-max_disp, max_disp])

    ax[4].plot(np.arange(N), zout[:N], 'k-')
    ax[4].axhline(y=0, c='k', lw=.5)
    ax[4].set_ylabel('output')
    ax[4].yaxis.tick_right()

    plt.tight_layout()
    plt.savefig(f"{save_dir}/string.png")
    plt.clf()
    plt.close()

    # plot bow params
    fig, ax = plt.subplots(figsize=(7,7), nrows=3, ncols=2)

    ax[0,0].plot(x_b, 'k-')                ; ax[0,1].plot(rels.numpy(), prof.numpy(), 'k-')
    ax[0,0].axhline(y=0, c='k', lw=.5)     ; ax[0,1].axhline(y=0, c='k', lw=.5)
    ax[0,0].set_ylabel('bowing position')  ; ax[0,1].set_ylabel('bow friction fn')
    ax[0,0].yaxis.tick_right()             ; ax[0,1].yaxis.tick_right()
    ax[0,0].set_ylim([0, 1])               ; ax[0,1].set_ylim([-1.5, 1.5])

    ax[1,0].plot(v_b, 'k-')                ; ax[1,1].plot(np.arange(N), v_r_out[:N], 'k-')
    ax[1,0].axhline(y=0, c='k', lw=.5)     ; ax[1,1].axhline(y=0, c='k', lw=.5)
    ax[1,0].set_ylabel('bowing velocity')  ; ax[1,1].set_ylabel('rel vel (attack)')
    ax[1,0].yaxis.tick_right()             ; ax[1,1].yaxis.tick_right()
    ax[1,0].set_ylim([0, 0.5])             ; ax[1,1].set_ylim([-2, 2])

    ax[2,0].plot(F_b, 'k-')                ; ax[2,1].plot(np.arange(N), v_r_out[-N:], 'k-')
    ax[2,0].axhline(y=0, c='k', lw=.5)     ; ax[2,1].axhline(y=0, c='k', lw=.5)
    ax[2,0].set_ylabel('bowing force')     ; ax[2,1].set_ylabel('rel vel (release)')
    ax[2,0].yaxis.tick_right()             ; ax[2,1].yaxis.tick_right()
    ax[2,0].set_ylim([0, 100])             ; ax[2,1].set_ylim([-2, 2])

    plt.tight_layout()
    plt.savefig(f"{save_dir}/bow.png")
    plt.clf()
    plt.close()

    # Friction coefficient estimated from the change of the relative velocity.
    sr = 48000
    Nt = len(v_r_out)
    Nx = state_u.shape[-1]
    a_f = (v_r_out[1:] - v_r_out[:Nt-1]) * sr
    F_f = a_f / Nx
    mu = F_f / F_b[-(Nt-1):]
    vr = v_r_out[:Nt-1]
    rels = torch.linspace(np.min(vr)-.1,np.max(vr)+.1,100)
    prof = hard_bow(rels, phi_0, phi_1)

    fig, ax = plt.subplots(figsize=(4,4), nrows=1, ncols=1)
    ax.fill_between(rels.numpy(), prof.numpy(), alpha=0.2, facecolor='r')
    ax.plot(vr, mu, 'k-')
    ax.axhline(y=0, c='k', lw=.5)
    ax.set_xlabel('Relative velocity')
    ax.set_ylabel('Friction coefficient')
    ax.set_ylim([-1.5, 1.5])

    plt.tight_layout()
    plt.savefig(f"{save_dir}/bow-velforce.pdf")
    plt.clf()
    plt.close()

    # plot hammer params
    fig, ax = plt.subplots(figsize=(7,7), nrows=2, ncols=1)

    sr = 48000
    # Time window in ms
    t_1 = 0; Nt_1 = int(sr * t_1 * 1e-3)
    t_2 = 8; Nt_2 = int(sr * t_2 * 1e-3)
    time = np.linspace(t_1, t_2, Nt_2 - Nt_1)
    ax[0].plot(time, u_H_out[Nt_1:Nt_2], 'k-')
    ax[0].axhline(y=0, c='k', lw=.5)
    ax[0].set_ylabel('hammer displacement')
    ax[0].yaxis.tick_right()

    ax[1].plot(time, F_H_out[Nt_1:Nt_2], 'k-')
    ax[1].axhline(y=0, c='k', lw=.5)
    ax[1].set_ylabel('hammer force')
    ax[1].yaxis.tick_right()

    plt.tight_layout()
    plt.savefig(f"{save_dir}/hammer.png")
    plt.clf()
    plt.close()
215
+
216
+
217
def state_specs(save_path, analytic, estimate, simulate):
    """Plot three state fields side by side with their errors and save the figure.

    analytic, estimate, simulate : arrays of shape (Nt, Nx); the panels are
    labelled 'Modal', 'Ours' and 'FDTD' respectively.

    Left column: the three fields (every 100th time step). Right column: the
    mid-string traces of all three (top) and the differences of the analytic
    and estimated fields to the simulated one.

    Fix: the figure is cleared before all figures are closed. The old order
    (`close('all')` then `clf()`) made `clf()` create a fresh empty figure
    that was left open.
    """
    tf = 100  # time decimation factor for the image panels
    Nt, Nx = simulate.shape
    nt = Nt // tf
    nx = Nx // 2
    diff_ana = analytic - simulate
    diff_est = estimate - simulate

    maxval = np.max(np.abs(simulate))
    maxerr = max(np.max(np.abs(diff_ana)), np.max(np.abs(diff_est)))

    nrows = 3; ncols = 2
    fig, ax = plt.subplots(ncols=ncols, nrows=nrows, figsize=(7,7))
    s_state = librosa.display.specshow(simulate[0::tf].T, cmap='coolwarm', ax=ax[0,0])
    a_state = librosa.display.specshow(analytic[0::tf].T, cmap='coolwarm', ax=ax[1,0])
    e_state = librosa.display.specshow(estimate[0::tf].T, cmap='coolwarm', ax=ax[2,0])

    a_diffs = librosa.display.specshow(diff_ana[0::tf].T, cmap='coolwarm', ax=ax[1,1])
    e_diffs = librosa.display.specshow(diff_est[0::tf].T, cmap='coolwarm', ax=ax[2,1])

    ax[0,1].plot(simulate[:nt,nx], c='goldenrod', label='FDTD')
    ax[0,1].plot(analytic[:nt,nx], c='r', label='Modal')
    ax[0,1].plot(estimate[:nt,nx], c='g', label='Ours')

    # Shared, symmetric color limits so that the panels are comparable.
    a_state.set_clim([-maxval, +maxval])
    e_state.set_clim([-maxval, +maxval])
    s_state.set_clim([-maxval, +maxval])
    a_diffs.set_clim([-maxerr, +maxerr])
    e_diffs.set_clim([-maxerr, +maxerr])
    titles = ['FDTD', 'Modal', 'Ours']
    for i, title in enumerate(titles):
        ax[i,0].set_ylabel(title)
    for i in range(nrows):
        for j in range(ncols):
            ax[i,j].set_xticks([])
            ax[i,j].set_yticks([])

    ax[0,1].legend(
        loc='lower center', bbox_to_anchor=(.95,-0.5),
        ncol=1, fancybox=True,
        handlelength=1., handletextpad=0.1, columnspacing=.5, fontsize=7,
    )

    fig.tight_layout()
    fig.subplots_adjust(wspace=0)
    fig.subplots_adjust(hspace=0)

    plt.savefig(save_path, bbox_inches='tight')
    plt.clf()
    plt.close('all')
267
+
268
def state_video(save_dir, state_u, sr, framerate=100, trim_front=True, verbose=False, prefix=None, fname='output', maxy=None):
    """Render state snapshots to PNG frames and assemble them into MP4 files.

    Frames are written to ``{save_dir}/temp`` and combined with ``ffmpeg``
    into ``{save_dir}/{prefix}-{fname}-silent_video.mp4``; that video is then
    muxed with ``{save_dir}/{fname}.wav`` into ``{save_dir}/{prefix}-{fname}.mp4``.
    The temporary frame directory is removed afterwards.

    Args:
        save_dir: Output directory (also where ``{fname}.wav`` is read from).
        state_u: 2-D array unpacked as ``(Nt, Nx)``, or a list whose first
            two items are the main state and a secondary state drawn
            semi-transparently behind it.
        sr: Sample rate; with ``trim_front`` only the first ``int(sr / 55)``
            rows are kept.
        framerate: Input frame rate given to ffmpeg; with ``trim_front`` it
            is also the target number of frames.
        trim_front: Keep only the leading ``int(sr / 55)`` rows.
        verbose: When false, ffmpeg runs with ``-loglevel quiet``.
        prefix: File name prefix, ``'fdtd'`` when ``None``.
        fname: Base name of the wav input and of the outputs.
        maxy: Symmetric y-limit; defaults to ``max(|state_u|)`` after trimming.
    """
    if isinstance(state_u, list):
        state_u, state_v = state_u[0], state_u[1]
    else:
        state_v = None

    if trim_front:
        n_keep = int(sr / 55)  # for 55 Hz (A1)
        state_u = state_u[:n_keep]
        if state_v is not None:
            state_v = state_v[:n_keep]
        # Clamp to 1: with fewer rows than `framerate` the stride would be 0
        # and the frame count below would divide by zero.
        downs = max(1, int(state_u.shape[0] / framerate))
    else:
        downs = 100

    Nt, Nx = state_u.shape
    maxy = np.max(np.abs(state_u)) if maxy is None else maxy
    locs = np.linspace(0, 1, Nx)

    # Create the frame directory once, up front, so the final cleanup always
    # has something to remove even when no frame is produced.
    temp_dir = f'{save_dir}/temp'
    os.makedirs(temp_dir, exist_ok=True)
    try:
        for j in range(Nt // downs):
            plt.figure(figsize=(5, 2))
            if state_v is not None:
                plt.plot(locs, state_v[j * downs], c='k', alpha=0.5)
            plt.plot(locs, state_u[j * downs], c='k')
            plt.xlim([0, 1])
            plt.ylim([-maxy, maxy])
            plt.xticks([])
            plt.yticks([])

            plt.tight_layout()
            plt.savefig(f'{temp_dir}/file{j:02d}.png')
            plt.clf()
            plt.close("all")

        prefix = 'fdtd' if prefix is None else prefix
        quiet = [] if verbose else ['-loglevel', 'quiet']
        silent_path = f'{save_dir}/{prefix}-{fname}-silent_video.mp4'
        silent_video = ['ffmpeg',
                        '-framerate', f'{framerate}',
                        '-i', f'{temp_dir}/file%02d.png',
                        '-r', '30', '-pix_fmt', 'yuv420p', '-y',
                        silent_path] + quiet
        output_video = ['ffmpeg',
                        '-i', silent_path,
                        '-i', f'{save_dir}/{fname}.wav',
                        '-c:v', 'copy', '-map', '0:v', '-map', '1:a',
                        '-shortest', '-y',
                        f'{save_dir}/{prefix}-{fname}.mp4'] + quiet
        subprocess.call(silent_video, stdout=subprocess.DEVNULL)
        subprocess.call(output_video, stdout=subprocess.DEVNULL)
    finally:
        # Remove the frames even if plotting or ffmpeg raised.
        shutil.rmtree(temp_dir)
321
+
322
def rainbowgram(
    save_path, out, sr, n_fft=2**13, hop_length=None,
    f0_input=None, f0_estimate=None, modes=None, colorbar=True,
):
    """Save a "rainbowgram": phase derivative coloured by hue, magnitude as alpha.

    The signal is RMS-normalised, transformed with an STFT, and the frame-to-frame
    difference of the (de-rotated, unwrapped) phase is drawn with the ``hsv``
    colour map on a log-frequency axis; the dB magnitude controls opacity.
    Optional frequency tracks are overlaid as white lines.

    Args:
        save_path: Path the figure is written to.
        out: Audio signal; its last axis is used as time.
        sr: Sample rate in Hz.
        n_fft: FFT size.  If the signal is not longer than ``2 * n_fft`` it is
            replaced by half the signal length.
        hop_length: STFT hop.  Defaults to ``n_fft // 32``; it is also reset
            to that value whenever ``n_fft`` is shrunk.
        f0_input: Optional frequency track, resampled onto the STFT frames.
        f0_estimate: Optional frequency track, resampled onto the STFT frames.
        modes: Optional iterable of frequency tracks, one line each.
        colorbar: Whether to add a phase colour bar.
    """
    L = 32  # n_fft / hop_length ratio
    if out.shape[-1] > 2 * n_fft:
        hop_length = n_fft // L if hop_length is None else hop_length
    else:
        n_fft = out.shape[-1] // 2
        hop_length = n_fft // L
    t_max = out.shape[-1] / sr

    out, _ = rms_normalize(out)
    D = librosa.stft(out, n_fft=n_fft, hop_length=hop_length, pad_mode='reflect')
    mag, phase = librosa.magphase(D)

    freqs = librosa.fft_frequencies(sr=sr, n_fft=n_fft)
    times = librosa.times_like(D, sr=sr, hop_length=hop_length)

    # Subtract the phase advance expected for each bin centre frequency, then
    # take the difference along the frame axis.
    phase_exp = 2 * np.pi * np.multiply.outer(freqs, times)
    unwrapped_phase = np.unwrap((np.angle(phase) - phase_exp) / (L / 4), axis=1)
    unwrapped_phase_diff = np.diff(unwrapped_phase, axis=1, prepend=0)

    # Map [-80, 0] dB (relative to the peak) onto [0, 1] opacity.
    alpha = librosa.amplitude_to_db(mag, ref=np.max) / 80 + 1

    width = 7
    height = 7
    fig, ax = plt.subplots(figsize=(width, height))
    spec = librosa.display.specshow(
        unwrapped_phase_diff, cmap='hsv', alpha=alpha,
        n_fft=n_fft, hop_length=hop_length, sr=sr, ax=ax,
        y_axis='log', x_axis='time',
    )
    ax.set_facecolor('#000')
    if colorbar:
        cbar = fig.colorbar(spec, ticks=[-np.pi, -np.pi/2, 0, np.pi/2, np.pi], ax=ax)
        # Raw strings: '\p' is not a valid escape sequence in a normal literal.
        cbar.ax.set(yticklabels=[r'$-\pi$', r'$-\pi/2$', "$0$", r'$\pi/2$', r'$\pi$'])

    def add_plot(track, label=None, ls=None, lw=2., dashes=(None, None)):
        """Resample ``track`` onto the STFT frame times and draw it in white."""
        x = np.linspace(1 / sr, t_max, track.shape[-1])
        track = np.interp(times, x, track)
        line, = ax.plot(times - times[0], track, label=label, color='white', lw=lw, ls=ls, dashes=dashes)
        return line

    if f0_input is not None:
        add_plot(f0_input, "f0_input", dashes=(10, 5))

    if f0_estimate is not None:
        add_plot(f0_estimate, "f0_estimate", dashes=(2, 5))

    if modes is not None:
        for im, m in enumerate(modes):
            mode_line = add_plot(m, f"mode {im}")
            mode_line.set_dashes([5, 10, 1, 10])

    ax.set_xticks([])
    ax.set_yticks([])
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)

    plt.tight_layout()
    plt.savefig(save_path, bbox_inches='tight', pad_inches=-1e-6)
    plt.clf()
    plt.close("all")
392
+
393
def phase_diagram(
    save_path, x, s,
    xmin, xmax,
    dxmin, dxmax,
    ddxmin, ddxmax,
    sr, tau=1, label='$u$'):
    """Save a signal trace next to its phase portrait (derivative vs. value).

    With ``s`` given, a second row is added showing the finite-difference
    time derivative of ``s`` as a colour map and its mean absolute value per
    column as a shaded profile.

    Args:
        save_path: Path the figure is written to (transparent background).
        x: 1-D signal.
        s: Optional 2-D state array indexed like ``x`` along axis 0; it is
            cut to ``len(x)`` rows when longer.  ``None`` gives the one-row layout.
        xmin, xmax: Limits for the signal value axis.
        dxmin, dxmax: Limits for the derivative axis.
        ddxmin, ddxmax: Not used by the active plots.
        sr: Sample rate used to scale the finite differences.
        tau: Lag, in samples, of the finite differences.
        label: Axis label for the signal.
    """
    step = tau / sr
    dx = (x[tau:] - x[:-tau]) / step
    # Second difference; not used by the active plots.
    ddx = (x[2*tau:] - 2*x[tau:-tau] + x[:-2*tau]) / (2*tau / sr)
    deriv_label = '$d$' + label + '$/dt$'

    def draw_trace(axis, with_xlabel):
        """Signal against sample index."""
        axis.axhline(y=0, color='gray', ls='-', lw=0.3)
        axis.plot(x, 'k-', lw=0.5)
        axis.set_xlim([0, len(x)])
        axis.set_ylim([xmin, xmax])
        axis.set_xticks([])
        axis.set_yticks([])
        if with_xlabel:
            axis.set_xlabel('$t$')
        axis.set_ylabel(label)

    def draw_portrait(axis, with_xlabel):
        """Signal value against its first difference."""
        axis.axhline(y=0, color='gray', ls='-', lw=0.3)
        axis.axvline(x=0, color='gray', ls='-', lw=0.3)
        axis.plot(dx, x[tau:], 'k-', lw=0.5)
        axis.set_xlim([dxmin, dxmax])
        axis.set_ylim([xmin, xmax])
        axis.set_xticks([])
        axis.set_yticks([])
        if with_xlabel:
            axis.set_xlabel(deriv_label)

    if s is None:
        fig, ax = plt.subplots(ncols=2, nrows=1, figsize=(8, 2), width_ratios=[4, 1])
        draw_trace(ax[0], with_xlabel=True)
        draw_portrait(ax[1], with_xlabel=True)
    else:
        if s.shape[0] > x.shape[0]:
            s = s[:x.shape[0]]
        dsdt = (s[tau:] - s[:-tau]) / step
        mean_abs = np.mean(np.abs(dsdt), axis=0)
        spax = np.arange(len(mean_abs))

        fig, ax = plt.subplots(ncols=2, nrows=2, figsize=(8, 3.5), width_ratios=[4, 1])
        # In the two-row layout the top panels carry no x labels.
        draw_trace(ax[0, 0], with_xlabel=False)
        draw_portrait(ax[0, 1], with_xlabel=False)

        state = librosa.display.specshow(dsdt.T, cmap='coolwarm', ax=ax[1, 0])
        maxabs = np.max(np.abs(dsdt))
        state.set_clim([-maxabs, +maxabs])
        ax[1, 0].set_xlim([0, x.shape[0]])
        ax[1, 0].set_xlabel('$t$')
        ax[1, 0].set_ylabel('$x$')

        # Pad with a zero on each side so the shaded profile closes at the ends.
        profile = np.pad(mean_abs, (1, 1))
        positions = np.pad(spax, (1, 1), mode='edge')
        for signed in (+profile, -profile):
            ax[1, 1].fill_between(signed, positions, alpha=0.2, facecolor='k')
        ax[1, 1].axvline(x=0, color='k', ls='-', lw=1.0)
        ax[1, 1].set_ylim([spax[0], spax[-1]])
        ax[1, 1].set_xticks([])
        ax[1, 1].set_yticks([])
        ax[1, 1].set_xlabel(deriv_label)

    plt.tight_layout()
    plt.subplots_adjust(wspace=0.)
    plt.subplots_adjust(hspace=0.)
    plt.savefig(save_path, bbox_inches='tight', transparent=True)
    plt.clf()
    plt.close("all")
472
+
473
+
474
+ #fig, ax = plt.subplots(ncols=2, nrows=1, figsize=(3.5,2))
475
+
476
+ #ax[0].axvline(x=0, color='gray', ls='-', lw=0.3)
477
+ #ax[0].axhline(y=0, color='gray', ls='-', lw=0.3)
478
+ #ax[0].plot(x[2*tau:], ddx, 'k-', lw=0.5)
479
+ #ax[0].set_xlim([xmin,xmax])
480
+ #ax[0].set_ylim([ddxmin,ddxmax])
481
+ #ax[0].set_xticks([])
482
+ #ax[0].set_yticks([])
483
+ #ax[0].set_xlabel(label)
484
+ #ax[0].set_ylabel('$d^2$'+label+'$/dt^2$')
485
+
486
+ #ax[1].axvline(x=0, color='gray', ls='-', lw=0.3)
487
+ #ax[1].axhline(y=0, color='gray', ls='-', lw=0.3)
488
+ #ax[1].plot(dx[tau:], ddx, 'k-', lw=0.5)
489
+ #ax[1].set_xlim([dxmin, dxmax])
490
+ #ax[1].set_ylim([ddxmin,ddxmax])
491
+ #ax[1].set_xticks([])
492
+ #ax[1].set_yticks([])
493
+ #ax[1].set_xlabel('$d$'+label+'$/dt$')
494
+
495
+ #plt.tight_layout()
496
+ #plt.subplots_adjust(wspace=0.)
497
+ #plt.subplots_adjust(hspace=0.)
498
+ #save_dir = save_path.split('/')[:-1]
499
+ #save_name = save_path.split('/')[-1]
500
+ #save_path_2 = '/'.join(save_dir+[save_name.replace('phs', 'dphs')])
501
+ #plt.savefig(save_path_2, bbox_inches='tight', transparent=True)
502
+ #plt.clf()
503
+ #plt.close("all")
504
+
505
+
506
def xt_grid_embedding(save_path, x, t, embed_dim=32, t_gain=1e-6, x_gain=1e-2):
    """Plot the sinusoidal embeddings of a time grid and a space grid side by side.

    Args:
        save_path: Path the figure is written to.
        x: 3-D tensor of positions (its shape is unpacked into three values).
        t: 3-D tensor of times; multiplied by 1000 before embedding.
        embed_dim: Embedding size passed to ``sinusoidal_embedding`` as ``n``.
        t_gain: Gain passed to ``sinusoidal_embedding`` for ``t``.
        x_gain: Gain passed to ``sinusoidal_embedding`` for ``x``.

    Raises:
        AssertionError: If either embedding is not 2-D after squeezing.
    """
    def as_numpy(tensor):
        """Drop singleton axes and move the tensor to a NumPy array."""
        return tensor.squeeze().detach().cpu().numpy()

    t = t * 1000

    # The unpacking doubles as a check that both inputs are 3-D.
    Bs, _, Nx = x.shape
    Bs, Nt, _ = t.shape
    t_embd = sinusoidal_embedding(t.unsqueeze(-1), n=embed_dim, gain=t_gain)
    x_embd = sinusoidal_embedding(x.unsqueeze(-1), n=embed_dim, gain=x_gain)

    t_axis = as_numpy(t)
    x_axis = as_numpy(x)
    t_embd = as_numpy(t_embd)
    x_embd = as_numpy(x_embd)
    assert t_embd.ndim == 2, t_embd.shape
    assert x_embd.ndim == 2, x_embd.shape
    dims = np.arange(embed_dim)

    fig, (t_ax, x_ax) = plt.subplots(figsize=(13, 7), nrows=1, ncols=2)
    librosa.display.specshow(t_embd, ax=t_ax, x_coords=dims, y_coords=t_axis)
    librosa.display.specshow(x_embd, ax=x_ax, x_coords=dims, y_coords=x_axis)

    t_ax.set_title("t embed")
    t_ax.set_xlabel("embedding dim")
    t_ax.set_ylabel("time")
    t_ax.set_yticks(t_axis[0::10])

    x_ax.set_title("x embed")
    x_ax.set_xlabel("embedding dim")
    x_ax.set_ylabel("space")
    x_ax.set_yticks(x_axis[0::10])
    x_ax.yaxis.set_label_position("right")
    x_ax.yaxis.tick_right()

    plt.tight_layout()
    plt.subplots_adjust(wspace=0.)
    plt.subplots_adjust(hspace=0.)
    plt.savefig(save_path)
    plt.clf()
    plt.close("all")
543
+
544
def logedc(save_path, logedc, tmax):
    """Plot a curve against a time axis spanning ``[0, tmax]`` and save it.

    Args:
        save_path: Path the figure is written to.
        logedc: Array plotted on the y axis (labelled "Energy (dB)"); its
            first axis sets the number of time points.
        tmax: End of the time axis, labelled in seconds.
    """
    n_points = logedc.shape[0]
    seconds = np.linspace(0, tmax, n_points)

    fig, ax = plt.subplots(figsize=(3, 3))
    ax.plot(seconds, logedc)
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Energy (dB)")

    plt.tight_layout()
    plt.savefig(save_path)
    plt.clf()
    plt.close("all")
556
+
557
def f0curve(save_path, f0_input, f0_estimate, first_mode, tmax):
    """Plot three frequency tracks over time and save the figure.

    Args:
        save_path: Path the figure is written to.
        f0_input: Track labelled ``$f_0$``.
        f0_estimate: Track labelled ``$f_0^{(\\tt est)}$``; its length sets
            the number of time points, so the other tracks must match it.
        first_mode: Track labelled ``$\\hat{f_0}$``.
        tmax: End of the time axis, labelled in seconds.
    """
    time = np.linspace(0, tmax, len(f0_estimate))

    fig, ax = plt.subplots(figsize=(3, 3))
    # Raw strings: '\h' is not a valid escape sequence in a normal literal.
    ax.plot(time, f0_input, label='$f_0$')
    ax.plot(time, f0_estimate, label=r'$f_0^{(\tt est)}$')
    ax.plot(time, first_mode, label=r'$\hat{f_0}$')
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Frequency (Hz)")
    ax.set_ylim(0, 200)

    plt.legend()
    plt.tight_layout()
    plt.savefig(save_path)
    plt.clf()
    plt.close("all")
573
+
574
def spectrum(save_path, out, f0_input, f0_estimate, modes, sr, n_fft=2**14, ylabel=None):
    """Plot the dB magnitude spectrum of the signal tail with frequency markers.

    The last ``n_fft`` samples of ``out`` are transformed with a real FFT and
    the first 1024 bins are drawn against frequency in kHz, restricted to
    0-2 kHz.  Vertical lines mark the last value of ``f0_input``, of
    ``f0_estimate`` and of every track in ``modes``.

    Args:
        save_path: Path the figure is written to.
        out: 1-D signal.
        f0_input: Frequency track in Hz; only the last value is used.
        f0_estimate: Frequency track in Hz; only the last value is used.
        modes: Iterable of frequency tracks in Hz; the last value of each is used.
        sr: Sample rate in Hz.
        n_fft: FFT size, capped at the signal length.
        ylabel: Label for the y axis.
    """
    n_fft = min(n_fft, out.shape[-1])
    simulated = out[-n_fft:]
    f0_input = f0_input[-1]
    f0_estimate = f0_estimate[-1]
    modes = [m[-1] for m in modes]

    simulated_fr = 20 * np.log10(np.abs(np.fft.rfft(simulated, n_fft)))
    freqs = np.linspace(0, sr/2 / 1000, int(n_fft/2+1))  # bin centres in kHz

    n_freqs = 1024  # number of low-frequency bins drawn

    fig, ax = plt.subplots(figsize=(4, 2))

    lw = 0.7
    ax.plot(freqs[:n_freqs], simulated_fr[:n_freqs], 'k', lw=1.)
    ax.axvline(x=f0_input / 1000, c='r', ls='-', lw=lw, label='$f_0$')
    ax.axvline(x=f0_estimate / 1000, c='g', ls='--', lw=lw, label=r'$f_0^{(\tt est)}$')
    for i, m in enumerate(modes):
        if i == 0:
            # Only the first mode is labelled so the legend shows one entry.
            # Raw string: '\h' is not a valid escape in a normal literal.
            ax.axvline(x=m / 1000, c='b', ls='-.', lw=lw, label=r'$\hat{f_p}$')
        else:
            ax.axvline(x=m / 1000, c='b', ls='-.', lw=lw)

    ax.set_xticks([0, 0.5, 1, 1.5, 2])
    plt.xlim([0, 2])
    plt.xlabel('Frequency (kHz)')
    plt.ylabel(ylabel)

    plt.legend(ncol=3, fancybox=True)
    plt.tight_layout()
    plt.savefig(save_path, bbox_inches='tight')
    plt.clf()
    plt.close("all")
610
+
611
+
612
def spectrum_uz(save_path, uout, zout, f0_input, f0_estimate, modes, sr, n_fft=2**14):
    """Plot the dB spectra of two signals in stacked panels with frequency markers.

    The last ``n_fft`` samples of ``uout`` (top, ``$|u|$``) and ``zout``
    (bottom, ``$|\\zeta|$``) are transformed with a real FFT and the first
    1024 bins are drawn against frequency in kHz over 0-2 kHz.  Vertical
    lines mark the last value of ``f0_input``, ``f0_estimate`` and of every
    track in ``modes``.

    Args:
        save_path: Path the figure is written to (transparent background).
        uout: 1-D signal for the top panel; its length caps ``n_fft``.
        zout: 1-D signal for the bottom panel.
        f0_input: Frequency track in Hz; only the last value is used.
        f0_estimate: Frequency track in Hz; only the last value is used.
        modes: Iterable of frequency tracks in Hz; the last value of each is used.
        sr: Sample rate in Hz.
        n_fft: FFT size, capped at the length of ``uout``.
    """
    n_fft = min(n_fft, uout.shape[-1])
    simulated_u = uout[-n_fft:]
    simulated_z = zout[-n_fft:]
    f0_input = f0_input[-1]
    f0_estimate = f0_estimate[-1]
    modes = [m[-1] for m in modes]

    simulated_fr_u = 20 * np.log10(np.abs(np.fft.rfft(simulated_u, n_fft)))
    simulated_fr_z = 20 * np.log10(np.abs(np.fft.rfft(simulated_z, n_fft)))
    freqs = np.linspace(0, sr/2 / 1000, int(n_fft/2+1))  # bin centres in kHz

    n_freqs = 1024  # number of low-frequency bins drawn

    fig, ax = plt.subplots(figsize=(2.5, 2), ncols=1, nrows=2)

    lw = 1.
    lw_fr = .5
    al = .5

    def draw_panel(axis, spectrum_db, zero_lw, xticks):
        """Draw one spectrum with its frequency markers on ``axis``."""
        axis.axhline(y=0, c='k', lw=zero_lw, alpha=al)
        axis.plot(freqs[:n_freqs], spectrum_db[:n_freqs], 'k', lw=lw_fr)
        # Raw strings: '\h' is not a valid escape sequence in a normal literal.
        axis.axvline(x=f0_input / 1000, c='r', ls='-', lw=lw, label='$f_0$', alpha=al)
        axis.axvline(x=f0_estimate / 1000, c='g', ls='--', lw=lw, label=r'$f_0^{(\tt est)}$', alpha=al)
        for i, m in enumerate(modes):
            if i == 0:
                # Only the first mode is labelled so the legend shows one entry.
                axis.axvline(x=m / 1000, c='b', ls=':', lw=lw, label=r'$\hat{f_p}$', alpha=al)
            else:
                axis.axvline(x=m / 1000, c='b', ls=':', lw=lw, alpha=al)
        axis.set_xticks(xticks)
        axis.set_xlim([0, 2])

    draw_panel(ax[0], simulated_fr_u, zero_lw=0.5, xticks=[0, 0.5, 1, 1.5, 2])
    ax[0].set_ylabel('$|u|$')
    ax[0].xaxis.set_label_position('top')
    ax[0].yaxis.tick_right()
    ax[0].xaxis.tick_top()

    draw_panel(ax[1], simulated_fr_z, zero_lw=0.3, xticks=[])
    ax[1].set_xlabel('Frequency (kHz)')
    ax[1].set_ylabel(r'$|\zeta|$')
    ax[1].yaxis.tick_right()
    ax[1].xaxis.set_label_coords(0.3, -0.1)
    plt.legend(loc='lower center', bbox_to_anchor=(.95, -0.5), ncol=3, fancybox=True, handlelength=1., handletextpad=0.1, columnspacing=.5, fontsize=7)

    plt.tight_layout()
    plt.subplots_adjust(wspace=0.)
    plt.subplots_adjust(hspace=0.)
    plt.savefig(save_path, bbox_inches='tight', transparent=True, pad_inches=-1e-6)
    plt.clf()
    plt.close("all")
677
+
678
+
679
def scatter_xy(save_path, x, y_dict, xlabel, ylabel, xticks=None, yticks=None):
    """Scatter several labelled y series against a shared x and save the figure.

    Args:
        save_path: Path the figure is written to (transparent background).
        x: x values shared by every series.
        y_dict: Mapping from legend label to y values.
        xlabel: Label for the x axis.
        ylabel: Label for the y axis.
        xticks: Tick positions for the x axis; ``None`` means no ticks.
        yticks: Tick positions for the y axis; ``None`` means no ticks.
    """
    # None replaces the former shared mutable ``[]`` defaults; the effect
    # (no ticks) is the same.
    xticks = [] if xticks is None else xticks
    yticks = [] if yticks is None else yticks

    fig, ax = plt.subplots(figsize=(2.5, 2.5))
    for y_label, y_values in y_dict.items():
        ax.scatter(x, y_values, label=y_label, s=1.)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    ax.set_xticks(xticks)
    ax.set_yticks(yticks)

    plt.legend()
    plt.tight_layout()
    plt.savefig(save_path, bbox_inches='tight', transparent=True)
    plt.clf()
    plt.close("all")
694
+
695
+
696
def scatter_kappa(save_path, total_summary, ss=.3):
    """Scatter f0 error against kappa, coloured by alpha, and save the figure.

    A grey line additionally shows ``f0_ground`` against ``kappa``, sorted by
    kappa and thinned to every 40th point plus the last one.

    Args:
        save_path: Path the figure is written to (transparent background).
        total_summary: Four equally long sequences
            ``(f0_diffs, f0_ground, kappa, alpha)``.
        ss: Marker size of the scatter points.
    """
    f0_diffs, f0_ground, kappa, alpha = total_summary

    sorted_kf = sorted(zip(kappa, f0_ground))
    sorted_kappa = [k for k, _ in sorted_kf]
    sorted_f0_ground = [f for _, f in sorted_kf]
    # Thin the reference curve, always keeping its end point.
    sorted_kappa = sorted_kappa[0::40] + [sorted_kappa[-1]]
    sorted_f0_ground = sorted_f0_ground[0::40] + [sorted_f0_ground[-1]]

    xticks = [5, 10, 15, 20]
    yticks = [0, 10, 20, 30, 40, 50, 60]

    fig, ax = plt.subplots(figsize=(2.5, 2), nrows=1, ncols=1)

    ax.plot(sorted_kappa, sorted_f0_ground, 'k-', lw=1.0, alpha=0.5)
    # The colour map is passed by name: ``plt.cm.get_cmap`` was removed in
    # Matplotlib 3.9, while a name string works on every version.
    sc = ax.scatter(kappa, f0_diffs, c=alpha, s=ss,
                    vmin=min(alpha), vmax=max(alpha), cmap='plasma')

    cbar = plt.colorbar(sc)
    cbar.ax.set_title(r'$\alpha$')
    cbar.ax.set_yticks([1, 10, 20, 25])
    ax.set_xticks(xticks)
    ax.set_yticks(yticks)
    ax.set_ylim([0, 60])
    # Hand-drawn light grid at the tick positions.
    for xt in xticks:
        ax.axvline(xt, c='k', ls='-', lw=0.5, alpha=0.3)
    for yt in yticks:
        ax.axhline(yt, c='k', ls='-', lw=0.5, alpha=0.3)
    # Raw string: '\k' is not a valid escape sequence in a normal literal.
    ax.set_xlabel(r'$\kappa$')
    ax.set_ylabel(r'$|f_0^{(\tt est)} - f_0|$ (Hz)')
    ax.xaxis.tick_top()

    plt.tight_layout()
    plt.savefig(save_path, bbox_inches='tight', transparent=True, pad_inches=-1e-5)
    plt.clf()
    plt.close("all")
740
+
741
+
742
+
743
def scatter_pluck(save_path, total_summary, ss=.3, al=0.7):
    """Scatter f0 detuning against kappa, p_x, p_a and optionally alpha.

    One panel is drawn per parameter, side by side with no spacing; every
    panel shows all series of ``f0_diffs`` using a fixed colour per series.

    Args:
        save_path: Path the figure is written to (transparent background).
        total_summary: ``(f0_diffs, kappa, alpha, p_x, p_a)``.  ``f0_diffs``
            maps a series label to its values; every label must be a key of
            the colour table below.  ``alpha`` may be ``None``, which drops
            the fourth panel.  ``p_a`` values are multiplied by 1e3 for display.
        ss: Marker size.
        al: Marker opacity.
    """
    # Raw strings keep the LaTeX backslashes literal ('\h' is not a valid
    # escape sequence); the resulting keys are unchanged.
    cmap = {
        r'$|f_0^{(\tt est)} - f_0|$': 'orchid',
        r'$|f_0^{(\tt est)} - \hat{f_0}|$': 'cadetblue',
    }

    f0_diffs, kappa, alpha, p_x, p_a = total_summary

    diff_max = max(max(item) for item in f0_diffs.values()) + 3.
    ncols = 3 if alpha is None else 4

    fig, ax = plt.subplots(figsize=(4., 2), nrows=1, ncols=ncols)

    def scatter_all(axis, xs):
        """Draw every detuning series against ``xs`` on ``axis``."""
        for y_label, diffs in f0_diffs.items():
            axis.scatter(xs, diffs, c=cmap[y_label], label=y_label, s=ss, alpha=al)

    # kappa
    scatter_all(ax[0], kappa)
    ax[0].axvline(x=5.88, c='k', ls='--', lw=0.5)
    ax[0].set_xlabel(r'$\kappa$')
    ax[0].set_ylabel('Detune')
    ax[0].set_ylim([0, diff_max])
    ax[0].set_xticks([2, 5, 8])
    ax[0].set_yticks([])
    ax[0].xaxis.tick_top()

    # p_x
    scatter_all(ax[1], p_x)
    ax[1].set_xlabel('$p_x$')
    ax[1].set_ylim([0, diff_max])
    ax[1].set_xticks([-0.5, 0])
    ax[1].set_yticks([])
    ax[1].xaxis.tick_top()
    ax[1].yaxis.tick_right()

    # p_a
    p_a = [value * 1e3 for value in p_a]
    scatter_all(ax[2], p_a)
    ax[2].set_xlabel(r'$p_a\times10^{3}$')
    ax[2].set_ylim([0, diff_max])
    ax[2].set_xticks([1, 4, 7, 10])
    ax[2].set_yticks([0, 5, 10])
    ax[2].xaxis.tick_top()
    ax[2].yaxis.tick_right()

    # alpha
    if alpha is not None:
        scatter_all(ax[3], alpha)
        ax[3].axhline(y=6, c='k', ls='--', lw=0.5)
        ax[3].axhline(y=1, c='k', ls='--', lw=0.5)
        ax[3].set_xlabel(r'$\alpha$')
        ax[3].set_ylim([0, diff_max])
        # The y ticks move from the third panel to this last one.
        ax[2].set_yticks([])
        ax[3].set_yticks([0, 5, 10])
        ax[3].xaxis.tick_top()

    plt.tight_layout()
    plt.legend(loc='lower center', bbox_to_anchor=(-0.5, -1.2), ncol=2, fancybox=True, handletextpad=0.02, columnspacing=.2, markerscale=5., fontsize=7)
    plt.subplots_adjust(wspace=0.)
    plt.subplots_adjust(hspace=0.)
    plt.savefig(save_path, bbox_inches='tight', transparent=True, pad_inches=-1e-5)
    plt.clf()
    plt.close("all")
815
+
816
+
817
+
818
def time_experiment(save_path, gpu_summary, cpu_summary):
    """Plot relative CPU and GPU run times against four workload settings.

    One log-log panel is drawn per key of ``gpu_summary``.  Times are divided
    by the first entry of their series, and a shaded area shows the setting
    divided by its first value (linear growth) for reference.

    Args:
        save_path: Path the figure is written to (transparent background).
        gpu_summary: Mapping from criterion label to a list of GPU times.
            Keys must be among the labels of the settings table below.
        cpu_summary: Mapping with the same keys holding CPU times.
    """
    batch_key = 'Batch size'
    nt_key = '$N_t$'
    ntl_key = '$N_x^{(\\tt t)}+N_x^{(\\tt l)}$'
    nl_key = '$N_x^{(\\tt l)}$'

    n_criteria = len(list(gpu_summary.keys()))
    fig, ax = plt.subplots(figsize=(5, 1.66), nrows=1, ncols=n_criteria)

    def f0_to_NtNl(f0, k=1/48000, theta_t=0.5 + 2/(np.pi**2), kappa_rel=0.03):
        """Total grid size (N_t + N_l) obtained for fundamental ``f0``."""
        gamma = 2 * f0
        kappa = gamma * kappa_rel
        IHP = (np.pi * kappa / gamma) ** 2  # inharmonicity parameter (>0); eq 7.21
        K = IHP ** .5 * (gamma / np.pi)  # set parameters
        denom = 2 * (2 * theta_t - 1)
        root = (gamma**4 * k**4 + 16 * K**2 * k**2 * (2 * theta_t - 1)) ** .5
        h_t = ((gamma**2 * k**2 + root) / denom) ** .5
        N_t = int(1 / h_t)
        alpha = 1
        h_l = gamma * alpha * k
        N_l = int(1 / h_l)
        return N_t + N_l

    def alpha_to_Nl(alpha, gamma=600, k=1/48000):
        """Grid size N_l obtained for ``alpha``."""
        return int(1 / (gamma * alpha * k))

    # Raw settings per criterion, converted to the quantity shown on the x axis.
    config = {
        batch_key: [4, 16, 64, 256, 1024],
        nt_key: [int(c * 48000) for c in [0.25, 0.50, 1.00, 2.00, 4.00]],
        ntl_key: [f0_to_NtNl(c) for c in [20, 40, 80, 160, 320]],
        nl_key: [alpha_to_Nl(c) for c in [1, 2, 3, 4]],
    }
    xlims = {
        batch_key: [2, 1800],
        nt_key: [6000, 300000],
        ntl_key: [70, 1900],
        nl_key: [15, 160],
    }

    print(config)

    thicklw = 0.8
    last_panel = n_criteria - 1
    for i, (criterion, gpu_times) in enumerate(gpu_summary.items()):
        panel = ax[i]
        conf_list = config[criterion]
        cpu_times = cpu_summary[criterion]

        if i > 1:
            # From the third panel on the series are flipped before plotting.
            conf_list = conf_list[::-1]
            gpu_times = list(gpu_times)[::-1]
            cpu_times = list(cpu_times)[::-1]
        # NOTE(review): the normalisation is taken to apply to every panel,
        # matching the "Relative time" axis label -- confirm against upstream.
        gpu_times = [t / gpu_times[0] for t in gpu_times]
        cpu_times = [t / cpu_times[0] for t in cpu_times]

        lin_times = [conf_list[k] / conf_list[0] for k in range(len(gpu_times))]

        panel.axhline(y=100, c='lightgray', lw=thicklw, ls=':')
        panel.axhline(y=10, c='lightgray', lw=thicklw, ls=':')
        panel.axhline(y=1, c='lightgray', lw=thicklw, ls='-')

        panel.plot(conf_list[:len(cpu_times)], cpu_times, 'kD--', lw=.9, label="CPU", mfc='lightgray')
        panel.plot(conf_list[:len(gpu_times)], gpu_times, 'ko-', lw=.9, label="GPU", mfc='white')

        panel.fill_between(conf_list[:len(cpu_times)], lin_times, alpha=.2)

        panel.set_xlabel(criterion)
        if i == 0:
            panel.set_ylabel('Relative time')

        panel.set_xscale('log')
        panel.set_yscale('log')

        panel.set_ylim([0.5, 1e3])

        panel.set_xlim(xlims[criterion])
        panel.xaxis.set_label_position('top')
        panel.yaxis.tick_right()
        # Only the last panel keeps its y ticks.
        panel.set_yticks([] if i < last_panel else [1, 10, 100, 1000])

    plt.tight_layout()
    plt.legend(loc='upper right', ncol=2, fancybox=True)
    plt.subplots_adjust(wspace=0.)
    plt.subplots_adjust(hspace=0.)
    plt.savefig(save_path, bbox_inches='tight', transparent=True)
    plt.clf()
    plt.close("all")
921
+
922
+
923
def est_tar_specs(est, tar, inp, plot_path, wave_path, sr=16000):
    """Plot input/estimate/target spectrograms per batch item and export audio.

    For each batch index a 4x2 figure is drawn (rows: input, estimate,
    target, target minus estimate; columns: ``"logmag"``, ``"logmel"``) and
    saved to ``plot_path``, and the three waveforms are written next to
    ``wave_path`` with ``-{b}-inp`` / ``-{b}-est`` / ``-{b}-tar`` suffixes.

    Args:
        est: Mapping with ``"wav"``, ``"logmag"`` and ``"logmel"`` entries
            indexed by batch (``.numpy()`` is called on the spectrogram
            entries, so presumably CPU tensors -- confirm with the caller).
        tar: Mapping with the same layout as ``est``.
        inp: Mapping with the same layout as ``est``.
        plot_path: Image path; the same path is rewritten for every batch item.
        wave_path: Template path containing ``'.wav'``.
        sr: Sample rate used for the written and logged audio.

    Returns:
        Dict with ``"columns"`` (four column names) and ``"data"`` (one row
        per batch item: a ``wandb.Image`` followed by three ``wandb.Audio``).
    """
    data = []
    batch_size = est["wav"].shape[0]
    for b in range(batch_size):
        logspecs = []
        difspecs = []

        nrows = 4
        ncols = 2
        height = 8
        widths = 7
        specfig, ax = plt.subplots(nrows, ncols, figsize=(widths, height))

        for col, key in enumerate(("logmag", "logmel")):
            diff_0 = tar[key][b] - est[key][b]
            for row, source in enumerate((inp, est, tar)):
                logspecs.append(
                    librosa.display.specshow(
                        source[key][b].numpy().T, cmap='magma', ax=ax[row, col]))
            difspecs.append(
                librosa.display.specshow(
                    diff_0.numpy().T, cmap='bwr', ax=ax[3, col]))

        for spec in logspecs:
            spec.set_clim([-60, 30])
        for spec in difspecs:
            spec.set_clim([-20, 20])

        titles = ['Analytic', 'Estimate', 'Original', 'Difference']
        for i, title in enumerate(titles):
            ax[i, 0].set_ylabel(title)

        specfig.tight_layout()
        specfig.subplots_adjust(wspace=0)
        specfig.subplots_adjust(hspace=0)

        specfig.savefig(plot_path)
        # No plt.clf() afterwards: with every figure closed, clf() would make
        # pyplot create a new empty figure that stays open.
        plt.close('all')

        inp_wav = inp["wav"][b].squeeze()
        sf.write(wave_path.replace('.wav', f"-{b}-inp.wav"), inp_wav, samplerate=sr)

        est_wav = est["wav"][b].squeeze()
        sf.write(wave_path.replace('.wav', f"-{b}-est.wav"), est_wav, samplerate=sr)
        tar_wav = tar["wav"][b].squeeze()
        sf.write(wave_path.replace('.wav', f"-{b}-tar.wav"), tar_wav, samplerate=sr)

        data.append([
            wandb.Image(specfig),
            wandb.Audio(inp_wav, sample_rate=sr),
            wandb.Audio(est_wav, sample_rate=sr),
            wandb.Audio(tar_wav, sample_rate=sr),
        ])

    columns = ["spec", "analytic", "estimate", "original"]
    return {
        "columns": columns,
        "data": data,
    }
999
+
1000
+
1001
def rde_specs(factors, est, sim, plot_path, wave_path, sr=16000):
    """Plot FDTD vs. PINN spectrograms and states per factor and export audio.

    Six figures are saved next to ``plot_path`` (``'rde.png'`` in the path is
    replaced by ``rde-mag``, ``rde-mel`` and four ``rde-state-*`` names), and
    two wav files per factor are written next to ``wave_path`` (``'rde.wav'``
    replaced by ``rde-pinn-{factor}`` / ``rde-fdtd-{factor}``).

    Args:
        factors: Sequence of numeric factors, one per row of every figure.
        est: Mapping with ``"logmag"``, ``"logmel"``, ``"state"`` and
            ``"wav"`` entries indexed by factor (PINN results).  State items
            are indexed as ``[:Nt, :, 0]`` and ``[:Nt, :, 1]`` and support
            ``.abs().max()`` / ``.numpy()`` (presumably CPU tensors -- confirm).
        sim: Mapping with the same layout as ``est`` (FDTD results).
        plot_path: Template image path containing ``'rde.png'``.
        wave_path: Template audio path containing ``'rde.wav'``.
        sr: Sample rate for the audio and for the 30 ms state window.

    Returns:
        Dict with ``"columns"`` (eight column names) and ``"data"`` (one row
        per factor: six ``wandb.Image`` and two ``wandb.Audio``).
    """
    data = []
    num_factors = len(factors)
    # plot_path = f'test/plot/rde.png'
    mag_path = plot_path.replace('rde.png', 'rde-mag.png')
    mel_path = plot_path.replace('rde.png', 'rde-mel.png')
    seu_path = plot_path.replace('rde.png', 'rde-state-pinn-u.png')
    sez_path = plot_path.replace('rde.png', 'rde-state-pinn-z.png')
    ssu_path = plot_path.replace('rde.png', 'rde-state-fdtd-u.png')
    ssz_path = plot_path.replace('rde.png', 'rde-state-fdtd-z.png')

    def factor_label(fc):
        """Row label shown on the left of each figure."""
        return r"$x\times" + f"{fc}$"

    def spec_grid(key, path):
        """Draw the FDTD (left) / PINN (right) spectrogram grid for ``key``."""
        specs = []
        fig, ax = plt.subplots(nrows=num_factors, ncols=2, figsize=(5, 7))
        for i in range(num_factors):
            specs.append(librosa.display.specshow(
                sim[key][i].numpy().T, cmap='magma', ax=ax[i, 0]))
            specs.append(librosa.display.specshow(
                est[key][i].numpy().T, cmap='magma', ax=ax[i, 1]))
        for spec in specs:
            spec.set_clim([-60, 30])
        for i, fc in enumerate(factors):
            ax[i, 0].set_ylabel(factor_label(fc))
        ax[0, 0].set_title('FDTD')
        ax[0, 1].set_title('PINN')
        fig.tight_layout()
        fig.subplots_adjust(wspace=0)
        fig.subplots_adjust(hspace=0)
        fig.savefig(path)
        # No plt.clf() afterwards: with every figure closed, clf() would make
        # pyplot create a new empty figure that stays open.
        plt.close('all')
        return fig

    #==============================
    # plot logmag / logmel
    #==============================
    magfig = spec_grid("logmag", mag_path)
    melfig = spec_grid("logmel", mel_path)

    #==============================
    # plot state
    #==============================
    u_states = []
    dustates = []
    z_states = []
    dzstates = []

    eu_fig, eu_ax = plt.subplots(num_factors, 2, figsize=(7, 7))
    ez_fig, ez_ax = plt.subplots(num_factors, 2, figsize=(7, 7))
    su_fig, su_ax = plt.subplots(num_factors, 2, figsize=(7, 7))
    sz_fig, sz_ax = plt.subplots(num_factors, 2, figsize=(7, 7))

    u_max = 0
    z_max = 0

    cm = 'coolwarm'
    Nt = int(sr * 30 / 1000)  # number of samples in the first 30 ms

    def show(state, channel, axis):
        """Colour map of one state channel over the first ``Nt`` samples."""
        return librosa.display.specshow(state[:Nt, :, channel].numpy().T, cmap=cm, ax=axis)

    for i, fc in enumerate(factors):
        # Differences are taken against the result for the last factor.
        e_dif = est["state"][i] - est["state"][-1]
        s_dif = sim["state"][i] - sim["state"][-1]

        u_states.append(show(sim["state"][i], 0, su_ax[i, 0]))
        u_states.append(show(est["state"][i], 0, eu_ax[i, 0]))
        dustates.append(show(s_dif, 0, su_ax[i, 1]))
        dustates.append(show(e_dif, 0, eu_ax[i, 1]))

        z_states.append(show(sim["state"][i], 1, sz_ax[i, 0]))
        z_states.append(show(est["state"][i], 1, ez_ax[i, 0]))
        dzstates.append(show(s_dif, 1, sz_ax[i, 1]))
        dzstates.append(show(e_dif, 1, ez_ax[i, 1]))

        u_max = max(u_max, sim["state"][i][:Nt, :, 0].abs().max(), est["state"][i][:Nt, :, 0].abs().max())
        z_max = max(z_max, sim["state"][i][:Nt, :, 1].abs().max(), est["state"][i][:Nt, :, 1].abs().max())
        for axes in (su_ax, eu_ax, sz_ax, ez_ax):
            axes[i, 0].set_ylabel(factor_label(fc))

    # States share one symmetric scale; differences use a tenth of it.
    for stat in u_states:
        stat.set_clim([-u_max, u_max])
    for stat in z_states:
        stat.set_clim([-z_max, z_max])
    for stat in dustates:
        stat.set_clim([-u_max/10, u_max/10])
    for stat in dzstates:
        stat.set_clim([-z_max/10, z_max/10])

    for fig in (eu_fig, ez_fig, su_fig, sz_fig):
        fig.tight_layout()
        fig.subplots_adjust(wspace=0)
        fig.subplots_adjust(hspace=0)

    eu_fig.savefig(seu_path)
    ez_fig.savefig(sez_path)
    su_fig.savefig(ssu_path)
    sz_fig.savefig(ssz_path)
    plt.close('all')

    columns = [
        "logmag", "logmel",
        "PINN-u", "FDTD-u",
        "PINN-z", "FDTD-z",
        "PINN wav", "FDTD wav",
    ]
    # NOTE(review): one table row per factor is assumed here (the figures are
    # shared, the audio differs) -- confirm against upstream.
    for i, factor in enumerate(factors):
        fstr = f"{factor:.1f}".replace('.', '_')
        # wave_path = f'test/wave/rde.wav'
        we_path = wave_path.replace('rde.wav', f'rde-pinn-{fstr}.wav')
        ws_path = wave_path.replace('rde.wav', f'rde-fdtd-{fstr}.wav')
        est_wav = est["wav"][i].squeeze()
        sim_wav = sim["wav"][i].squeeze()
        sf.write(we_path, est_wav, samplerate=sr)
        sf.write(ws_path, sim_wav, samplerate=sr)

        data.append([
            wandb.Image(magfig),
            wandb.Image(melfig),
            wandb.Image(eu_fig),
            wandb.Image(su_fig),
            wandb.Image(ez_fig),
            wandb.Image(sz_fig),
            wandb.Audio(est_wav, sample_rate=sr),
            wandb.Audio(sim_wav, sample_rate=sr),
        ])

    return {
        "columns": columns,
        "data": data,
    }
1131
+
1132
+