Ahsen Khaliq committed on
Commit
0fbd9ed
·
1 Parent(s): c43590a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +269 -0
app.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
# Download every pre-trained checkpoint that is loaded with torch.load()
# further down in this file. Each file loaded there must be fetched here;
# "synth2synth_full.pt" was previously missing while "delay_full.pt" was
# downloaded twice.
os.system("wget https://csteinmetz1.github.io/steerable-nafx/models/compressor_full.pt")
os.system("wget https://csteinmetz1.github.io/steerable-nafx/models/reverb_full.pt")
os.system("wget https://csteinmetz1.github.io/steerable-nafx/models/amp_full.pt")
os.system("wget https://csteinmetz1.github.io/steerable-nafx/models/delay_full.pt")
os.system("wget https://csteinmetz1.github.io/steerable-nafx/models/synth2synth_full.pt")
8
+
9
+ import sys
10
+ import math
11
+ import torch
12
+ import librosa.display
13
+ import IPython
14
+ import auraloss
15
+ import torchaudio
16
+ import numpy as np
17
+ import scipy.signal
18
+ from google.colab import files
19
+ from tqdm.notebook import tqdm
20
+ from time import sleep
21
+ import matplotlib
22
+ import pyloudnorm as pyln
23
+ import matplotlib.pyplot as plt
24
+ from IPython.display import Image
25
+
26
def measure_rt60(h, fs=1, decay_db=30, rt60_tgt=None):
    """
    Analyze the RT60 of an impulse response.

    The energy decay curve is obtained by backward (Schroeder) integration of
    the squared response. The time taken to fall from -5 dB to
    ``-5 - decay_db`` dB is measured and scaled linearly to a 60 dB decay.

    Args:
        h (ndarray): The discrete time impulse response as 1d array.
        fs (float, optional): Sample rate of the impulse response. With the
            default of 1 the result is expressed in samples. (Default: 1)
        decay_db (float, optional): The decay in decibels for which we
            actually estimate the time. (Default: 30)
        rt60_tgt (float, optional): Target RT60. Accepted for interface
            compatibility; the computation does not use it. (Default: None)
    Returns:
        est_rt60 (float): Estimated RT60. A 0-d array holding 0.0 is returned
            when the estimate cannot be computed (silent or too-short
            response, or a decay that never reaches the required level).
    """

    h = np.array(h)
    fs = float(fs)

    # The power of the impulse response
    power = h ** 2
    energy = np.cumsum(power[::-1])[::-1]  # Integration according to Schroeder

    try:
        # remove the possibly all zero tail
        i_nz = np.max(np.where(energy > 0)[0])
        energy = energy[:i_nz]
        energy_db = 10 * np.log10(energy)
        energy_db -= energy_db[0]

        # -5 dB headroom: first index at which the curve is below -5 dB
        i_5db = np.min(np.where(-5 - energy_db > 0)[0])
        t_5db = i_5db / fs

        # first index at which the curve has dropped a further decay_db
        i_decay = np.min(np.where(-5 - decay_db - energy_db > 0)[0])
        t_decay = i_decay / fs

        # compute the decay time and extrapolate it to 60 dB
        decay_time = t_decay - t_5db
        est_rt60 = (60 / decay_db) * decay_time
    except (ValueError, IndexError):
        # ValueError: np.max / np.min over an empty index set (no energy, or
        # the threshold is never crossed). IndexError: nothing is left after
        # trimming the tail, so energy_db[0] does not exist.
        est_rt60 = np.array(0.0)

    return est_rt60
68
+
69
def causal_crop(x, length: int):
    """Return a ``length``-sample window taken from the end of the last axis.

    When the last axis already has ``length`` samples, ``x`` is returned
    unchanged. Otherwise the window stops one sample before the final sample,
    i.e. it covers indices ``[size - 1 - length, size - 1)``.
    """
    size = x.shape[-1]
    if size == length:
        return x
    # NOTE(review): the final sample is excluded on purpose here; trained
    # weights may rely on this alignment, so it is left as is.
    end = size - 1
    return x[..., end - length:end]
75
+
76
class FiLM(torch.nn.Module):
    """Feature-wise linear modulation driven by a conditioning vector.

    A linear layer maps the conditioning input to one scale and one shift
    value per feature channel, which are then applied to the (optionally
    batch-normalized) feature map.
    """

    def __init__(self, cond_dim, num_features, batch_norm=True):
        """
        Args:
            cond_dim: dim of conditioning input
            num_features: dim of the conv channel
            batch_norm: normalize the features (without learned affine)
                before applying the conditional affine
        """
        super().__init__()
        self.num_features = num_features
        self.batch_norm = batch_norm
        if batch_norm:
            self.bn = torch.nn.BatchNorm1d(num_features, affine=False)
        # one linear layer produces both the scales and the shifts
        self.adaptor = torch.nn.Linear(cond_dim, num_features * 2)

    def forward(self, x, cond):
        """Modulate ``x`` with the scale/shift predicted from ``cond``."""
        scale, shift = self.adaptor(cond).chunk(2, dim=-1)
        # swap the last two axes so the per-feature values line up with the
        # channel axis of x
        scale = scale.permute(0, 2, 1)
        shift = shift.permute(0, 2, 1)

        # BatchNorm without affine, followed by the conditional affine
        features = self.bn(x) if self.batch_norm else x
        return (features * scale) + shift
102
+
103
class TCNBlock(torch.nn.Module):
    """One dilated convolution stage with an additive 1x1 residual path.

    The main path is an unpadded dilated convolution, optionally followed by
    FiLM conditioning (when ``cond_dim > 0``) and a PReLU (when
    ``activation`` is true). The residual path is a bias-free 1x1 convolution
    of the block input, cropped to the length of the main path.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation, cond_dim=0, activation=True):
        super().__init__()
        # no padding: the output is shorter than the input
        self.conv = torch.nn.Conv1d(
            in_channels, out_channels, kernel_size,
            dilation=dilation, padding=0, bias=True,
        )
        # the optional stages exist as attributes only when enabled;
        # forward() checks for their presence
        if cond_dim > 0:
            self.film = FiLM(cond_dim, out_channels, batch_norm=False)
        if activation:
            self.act = torch.nn.PReLU()
        self.res = torch.nn.Conv1d(in_channels, out_channels, 1, bias=False)

    def forward(self, x, c=None):
        """Run the block on ``x``; ``c`` is the conditioning passed to FiLM."""
        y = self.conv(x)
        if hasattr(self, "film"):
            y = self.film(y, c)
        if hasattr(self, "act"):
            y = self.act(y)
        # bring the residual to the (shorter) length of the conv output
        return y + causal_crop(self.res(x), y.shape[-1])
131
+
132
class TCN(torch.nn.Module):
    """Stack of TCNBlocks whose dilation is ``dilation_growth ** block_index``."""

    def __init__(self, n_inputs=1, n_outputs=1, n_blocks=10, kernel_size=13, n_channels=64, dilation_growth=4, cond_dim=0):
        super().__init__()
        self.kernel_size = kernel_size
        self.n_channels = n_channels
        self.dilation_growth = dilation_growth
        self.n_blocks = n_blocks
        self.stack_size = n_blocks

        self.blocks = torch.nn.ModuleList()
        for idx in range(n_blocks):
            is_first = idx == 0
            # the first-block rule takes precedence when there is one block
            is_last = (not is_first) and (idx + 1) == n_blocks
            self.blocks.append(
                TCNBlock(
                    n_inputs if is_first else n_channels,
                    n_outputs if is_last else n_channels,
                    kernel_size,
                    dilation_growth ** idx,
                    cond_dim=cond_dim,
                    activation=True,  # every block, including the last, is activated
                )
            )

    def forward(self, x, c=None):
        """Pass ``x`` through every block in order, with conditioning ``c``."""
        for layer in self.blocks:
            x = layer(x, c)
        return x

    def compute_receptive_field(self):
        """Compute the receptive field in samples."""
        reach = self.kernel_size - 1
        rf = self.kernel_size
        # each block after the first widens the field by (kernel_size - 1)
        # times its dilation
        for n in range(1, self.n_blocks):
            rf += reach * self.dilation_growth ** (n % self.stack_size)
        return rf
172
+
173
# setup the pre-trained models: load each checkpoint on the CPU and switch it
# to evaluation mode.
# NOTE(review): torch.load() unpickles the downloaded files, which can execute
# arbitrary code; only load checkpoints from a trusted source. The loaded
# objects are used as modules directly, so the model classes defined in this
# file presumably have to be importable when unpickling -- confirm.
model_comp, model_verb, model_amp, model_delay, model_synth = (
    torch.load(checkpoint, map_location="cpu").eval()
    for checkpoint in (
        "compressor_full.pt",
        "reverb_full.pt",
        "amp_full.pt",
        "delay_full.pt",
        "synth2synth_full.pt",
    )
)
179
+
180
+
181
+
182
def _normalize_peak_(signal):
    """Scale ``signal`` in place to a peak magnitude of 1; silence is left as is."""
    peak = signal.abs().max()
    if peak > 0:  # avoid 0/0 -> NaN on an all-zero signal
        signal /= peak
    return signal


def inference(aud, effect_type):
    """Process an uploaded clip with one of the pre-trained effect models.

    Args:
        aud: Uploaded audio object; its ``file`` attribute is handed to
            ``torchaudio.load``.
        effect_type (str): One of "Compressor", "Reverb", "Amp",
            "Analog Delay" or "Synth2Synth".

    Returns:
        str: Path of the rendered audio file ("output.mp3").

    Raises:
        ValueError: If ``effect_type`` is not one of the known effects, or if
            the clip is too short to survive the start-up transient crop.
    """
    x_p, sample_rate = torchaudio.load(aud.file)

    # fixed processing settings
    gain_dB = -24       # input gain in dB
    c0 = -1.4           # first conditioning value
    c1 = 3              # second conditioning value
    mix = 70            # wet share of the output, in percent
    width = 50          # offset applied to the conditioning per channel
    max_length = 30     # seconds of input that are processed
    stereo = True       # duplicate a mono input to two channels
    tail = True         # also pad the end so the effect tail is rendered
    transient = 8192    # samples dropped from the start of the result

    # select model type
    models = {
        "Compressor": model_comp,
        "Reverb": model_verb,
        "Amp": model_amp,
        "Analog Delay": model_delay,
        "Synth2Synth": model_synth,
    }
    if effect_type not in models:
        raise ValueError(
            "Unknown effect_type {!r}; expected one of {}".format(
                effect_type, sorted(models)
            )
        )
    pt_model = models[effect_type]

    # measure the receptive field
    pt_model_rf = pt_model.compute_receptive_field()

    # crop input signal if needed
    max_samples = int(sample_rate * max_length)
    x_p_crop = x_p[:, :max_samples]
    chs = x_p_crop.shape[0]

    # if mono and stereo requested
    if chs == 1 and stereo:
        x_p_crop = x_p_crop.repeat(2, 1)
        chs = 2

    # pad the input signal
    front_pad = pt_model_rf - 1
    back_pad = front_pad if tail else 0

    # everything up to `transient` is discarded at the end, so a shorter
    # result would be empty; fail early with a clear message instead
    if x_p_crop.shape[1] + back_pad <= transient:
        raise ValueError(
            "Input audio is too short: more than {} samples are required".format(
                transient
            )
        )

    x_p_pad = torch.nn.functional.pad(x_p_crop, (front_pad, back_pad))

    # design highpass filter
    sos = scipy.signal.butter(
        8,
        20.0,
        fs=sample_rate,
        output="sos",
        btype="highpass",
    )

    # compute linear gain
    gain_ln = 10 ** (gain_dB / 20.0)

    # process audio with pre-trained model
    spread = width * 5e-3
    with torch.no_grad():
        y_hat = torch.zeros(x_p_crop.shape[0], x_p_crop.shape[1] + back_pad)
        for n in range(chs):
            # the first channel is shifted up, every other channel down
            factor = spread if n == 0 else -spread
            c = torch.tensor([float(c0 + factor), float(c1 + factor)]).view(1, 1, -1)
            y_hat_ch = pt_model(gain_ln * x_p_pad[n, :].view(1, 1, -1), c)
            y_hat_ch = scipy.signal.sosfilt(sos, y_hat_ch.view(-1).numpy())
            y_hat[n, :] = torch.tensor(y_hat_ch)

    # pad the dry signal
    x_dry = torch.nn.functional.pad(x_p_crop, (0, back_pad))

    # normalize each first
    _normalize_peak_(y_hat)
    _normalize_peak_(x_dry)

    # mix
    mix = mix / 100.0
    y_hat = (mix * y_hat) + ((1 - mix) * x_dry)

    # remove transient
    y_hat = y_hat[..., transient:]
    _normalize_peak_(y_hat)

    torchaudio.save("output.mp3", y_hat.view(chs, -1), sample_rate, compression=320.0)
    return "output.mp3"
268
+
269
+