cwitkowitz committed
Commit 883013e
1 Parent(s): 94fc053

Working standalone.

Files changed (7)
  1. .gitignore +3 -0
  2. app.py +83 -0
  3. model-8750.pt +3 -0
  4. models/__init__.py +0 -0
  5. models/cqt_module.py +281 -0
  6. models/transcriber.py +626 -0
  7. requirements.txt +6 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
+ *__pycache__
+ _outputs
+ .idea
app.py ADDED
@@ -0,0 +1,83 @@
+ from pyharp import ModelCard, build_endpoint
+
+ import gradio as gr
+ import torchaudio
+ import torch
+ import os
+
+ timbre_trap = torch.load('model-8750.pt', map_location='cpu')
+
+ card = ModelCard(
+     name='Timbre-Trap',
+     description='De-timbre your audio!',
+     author='Frank Cwitkowitz',
+     tags=['example', 'music transcription', 'multi-pitch estimation', 'timbre filtering']
+ )
+
+
+ def process_fn(audio_path, de_timbre):
+     # Load the audio with torchaudio
+     audio, fs = torchaudio.load(audio_path)
+     # Average channels to obtain mono-channel
+     audio = torch.mean(audio, dim=0, keepdim=True)
+     # Resample audio to the specified sampling rate
+     audio = torchaudio.functional.resample(audio, fs, 22050)
+     # Add a batch dimension
+     audio = audio.unsqueeze(0)
+     # Determine original number of samples
+     n_samples = audio.size(-1)
+     # Pad audio to next multiple of block length
+     audio = timbre_trap.sliCQ.pad_to_block_length(audio)
+
+     # Encode raw audio into latent vectors
+     latents, embeddings, _ = timbre_trap.encode(audio)
+     # Apply skip connections if they are turned on
+     embeddings = timbre_trap.apply_skip_connections(embeddings)
+     # Obtain transcription or reconstructed spectral coefficients
+     coefficients = timbre_trap.decode(latents, embeddings, de_timbre)
+
+     # Invert reconstructed spectral coefficients
+     audio = timbre_trap.sliCQ.decode(coefficients)
+     # Trim to original number of samples
+     audio = audio[..., :n_samples]
+     # Remove batch dimension
+     audio = audio.squeeze(0)
+
+     if de_timbre and audio.abs().max():
+         # Normalize audio to [-1, 1]
+         audio /= audio.abs().max()
+
+     # Create a temporary directory for output
+     os.makedirs('_outputs', exist_ok=True)
+     # Create a path for saving the audio
+     save_path = os.path.join('_outputs', 'output.wav')
+     # Save the audio
+     torchaudio.save(save_path, audio, 22050)
+
+     return save_path
+
+
+ with gr.Blocks() as demo:
+     inputs = [
+         gr.Audio(
+             label='Audio Input',
+             type='filepath'
+         ),
+         #gr.Checkbox(
+         #    value=False,
+         #    label='De-Timbre'
+         #)
+         gr.Slider(
+             minimum=0,
+             maximum=1,
+             step=1,
+             value=0,
+             label='De-Timbre'
+         )
+     ]
+
+     output = gr.Audio(label='Audio Output', type='filepath')
+
+     ctrls_data, ctrls_button, process_button = build_endpoint(inputs, output, process_fn, card)
+
+     demo.launch(share=True)
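The same pipeline can also be exercised on the pickled checkpoint without the Gradio UI. A minimal sketch, assuming the checkpoint has already been pulled via git-lfs and that a test file named input.wav exists (both the file name and the de-timbre switch value below are illustrative assumptions):

import torch
import torchaudio

# Load the pickled Timbre-Trap module (the models/ package must be importable)
model = torch.load('model-8750.pt', map_location='cpu')

# 'input.wav' is an assumed test file
audio, fs = torchaudio.load('input.wav')
audio = torch.mean(audio, dim=0, keepdim=True).unsqueeze(0)
audio = torchaudio.functional.resample(audio, fs, 22050)
audio = model.sliCQ.pad_to_block_length(audio)

latents, embeddings, _ = model.encode(audio)
embeddings = model.apply_skip_connections(embeddings)

# True selects the transcription (de-timbre) branch; False reconstructs the input
coefficients = model.decode(latents, embeddings, True)
de_timbred = model.sliCQ.decode(coefficients)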
model-8750.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e1eb515001ebb871a934379bbd44a22e00a2f41b20c34cd862274aa04c0ca900
+ size 11401913
models/__init__.py ADDED
File without changes
models/cqt_module.py ADDED
@@ -0,0 +1,281 @@
+ from torchaudio.transforms import AmplitudeToDB
+ from cqt_pytorch import CQT as _CQT
+
+ import numpy as np
+ import librosa
+ import torch
+ import math
+
+
+ class CQT(_CQT):
+     """
+     Wrapper which adds some basic functionality to the sliCQ module.
+     """
+
+     def __init__(self, n_octaves, bins_per_octave, sample_rate, secs_per_block):
+         """
+         Instantiate the sliCQ module and wrapper.
+
+         Parameters
+         ----------
+         n_octaves : int
+             Number of octaves below Nyquist to span
+         bins_per_octave : int
+             Number of bins allocated to each octave
+         sample_rate : int or float
+             Number of samples per second of audio
+         secs_per_block : float
+             Number of seconds to process at a time
+         """
+
+         super().__init__(num_octaves=n_octaves,
+                          num_bins_per_octave=bins_per_octave,
+                          sample_rate=sample_rate,
+                          block_length=int(secs_per_block * sample_rate),
+                          power_of_2_length=True)
+
+         self.sample_rate = sample_rate
+
+         # Compute hop length corresponding to transform coefficients
+         self.hop_length = (self.block_length / self.max_window_length)
+
+         # Compute total number of bins
+         self.n_bins = n_octaves * bins_per_octave
+         # Determine frequency (MIDI) below Nyquist by specified octaves
+         fmin = librosa.hz_to_midi((sample_rate / 2) / (2 ** n_octaves))
+
+         # Determine center frequency (MIDI) associated with each bin of module
+         self.midi_freqs = fmin + np.arange(self.n_bins) / (bins_per_octave / 12)
+
+     def forward(self, audio):
+         """
+         Encode a batch of audio into CQT spectral coefficients.
+
+         Parameters
+         ----------
+         audio : Tensor (B x 1 x T)
+             Batch of input audio
+
+         Returns
+         ----------
+         coefficients : Tensor (B x 2 x F x T)
+             Batch of real/imaginary CQT coefficients
+         """
+
+         with torch.no_grad():
+             # Obtain complex CQT coefficients
+             coefficients = self.encode(audio)
+
+             # Convert complex coefficients to real representation
+             coefficients = self.to_real(coefficients)
+
+         return coefficients
+
+     @staticmethod
+     def to_real(coefficients):
+         """
+         Convert a set of complex coefficients to equivalent real representation.
+
+         Parameters
+         ----------
+         coefficients : Tensor (B x 1 x F x T)
+             Batch of complex CQT coefficients
+
+         Returns
+         ----------
+         coefficients : Tensor (B x 2 x F x T)
+             Batch of real/imaginary CQT coefficients
+         """
+
+         # Collapse channel dimension (mono assumed)
+         coefficients = coefficients.squeeze(-3)
+         # Convert complex coefficients to real and imaginary
+         coefficients = torch.view_as_real(coefficients)
+         # Place real and imaginary coefficients under channel dimension
+         coefficients = coefficients.transpose(-1, -2).transpose(-2, -3)
+
+         return coefficients
+
+     @staticmethod
+     def to_complex(coefficients):
+         """
+         Convert a set of real coefficients to their equivalent complex representation.
+
+         Parameters
+         ----------
+         coefficients : Tensor (B x 2 x F x T)
+             Batch of real/imaginary CQT coefficients
+
+         Returns
+         ----------
+         coefficients : Tensor (B x F x T)
+             Batch of complex CQT coefficients
+         """
+
+         # Move real and imaginary coefficients to last dimension
+         coefficients = coefficients.transpose(-3, -2).transpose(-2, -1)
+         # Convert real and imaginary coefficients to complex
+         coefficients = torch.view_as_complex(coefficients.contiguous())
+
+         return coefficients
+
+     @staticmethod
+     def to_magnitude(coefficients):
+         """
+         Compute the magnitude for a set of real coefficients.
+
+         Parameters
+         ----------
+         coefficients : Tensor (B x 2 x F x T)
+             Batch of real/imaginary CQT coefficients
+
+         Returns
+         ----------
+         magnitude : Tensor (B x F x T)
+             Batch of magnitude coefficients
+         """
+
+         # Compute L2-norm of coefficients to compute magnitude
+         magnitude = coefficients.norm(p=2, dim=-3)
+
+         return magnitude
+
+     @staticmethod
+     def to_decibels(magnitude, rescale=True):
+         """
+         Convert a set of magnitude coefficients to decibels.
+
+         TODO - move 0 dB only if maximum is higher?
+              - currently it's consistent with previous dB scaling
+              - currently it's only used for visualization
+
+         Parameters
+         ----------
+         magnitude : Tensor (B x F x T)
+             Batch of magnitude coefficients (amplitude)
+         rescale : bool
+             Rescale decibels to the range [0, 1]
+
+         Returns
+         ----------
+         decibels : Tensor (B x F x T)
+             Batch of magnitude coefficients (dB)
+         """
+
+         # Initialize a differentiable conversion to decibels
+         decibels = AmplitudeToDB(stype='amplitude', top_db=80)(magnitude)
+
+         if rescale:
+             # Make 0 dB ceiling
+             decibels -= decibels.max()
+             # Rescale decibels to range [0, 1]
+             decibels = 1 + decibels / 80
+
+         return decibels
+
+     def decode(self, coefficients):
+         """
+         Invert CQT spectral coefficients to synthesize audio.
+
+         Parameters
+         ----------
+         coefficients : Tensor (B x 2 OR 1 x F x T)
+             Batch of real/imaginary OR complex CQT coefficients
+
+         Returns
+         ----------
+         output : Tensor (B x 1 x T)
+             Batch of reconstructed audio
+         """
+
+         with torch.no_grad():
+             if not coefficients.is_complex():
+                 # Convert real coefficients to complex representation
+                 coefficients = self.to_complex(coefficients)
+                 # Add a channel dimension to coefficients
+                 coefficients = coefficients.unsqueeze(-3)
+
+             # Decode the complex CQT coefficients
+             audio = super().decode(coefficients)
+
+         return audio
+
+     def pad_to_block_length(self, audio):
+         """
+         Pad audio to the next multiple of block length such that it can be processed in full.
+
+         Parameters
+         ----------
+         audio : Tensor (B x 1 x T)
+             Batch of audio
+
+         Returns
+         ----------
+         audio : Tensor (B x 1 x T + p)
+             Batch of padded audio
+         """
+
+         # Pad the audio with zeros to fill up the remainder of the final block
+         audio = torch.nn.functional.pad(audio, (0, -audio.size(-1) % self.block_length))
+
+         return audio
+
+     def get_expected_samples(self, t):
+         """
+         Determine the number of samples corresponding to a specified amount of time.
+
+         Parameters
+         ----------
+         t : float
+             Amount of time
+
+         Returns
+         ----------
+         num_samples : int
+             Number of audio samples expected
+         """
+
+         # Compute number of samples and round down
+         num_samples = int(max(0, t) * self.sample_rate)
+
+         return num_samples
+
+     def get_expected_frames(self, num_samples):
+         """
+         Determine the number of frames the module will return for a given number of samples.
+
+         Parameters
+         ----------
+         num_samples : int
+             Number of audio samples available
+
+         Returns
+         ----------
+         num_frames : int
+             Number of frames expected
+         """
+
+         # Number of frames of coefficients per chunk times the number of chunks
+         num_frames = math.ceil((num_samples / self.block_length) * self.max_window_length)
+
+         return num_frames
+
+     def get_times(self, n_frames):
+         """
+         Determine the time associated with each frame of coefficients.
+
+         Parameters
+         ----------
+         n_frames : int
+             Number of frames available
+
+         Returns
+         ----------
+         times : ndarray (T)
+             Time (seconds) associated with each frame
+         """
+
+         # Compute times as cumulative hops in seconds
+         times = np.arange(n_frames) * self.hop_length / self.sample_rate
+
+         return times
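A rough round-trip sketch of the wrapper above; the parameter values are illustrative assumptions, not necessarily the settings used for the released checkpoint:

import torch
from models.cqt_module import CQT

# Illustrative settings (assumptions, not the released configuration)
cqt = CQT(n_octaves=9, bins_per_octave=60, sample_rate=22050, secs_per_block=3)

audio = torch.randn(1, 1, 4 * 22050)          # (B x 1 x T) batch of audio
audio = cqt.pad_to_block_length(audio)        # pad to a multiple of the block length

coefficients = cqt(audio)                     # (B x 2 x F x T) real/imaginary coefficients
reconstruction = cqt.decode(coefficients)     # (B x 1 x T) resynthesized audio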
models/transcriber.py ADDED
@@ -0,0 +1,626 @@
+ from .cqt_module import CQT
+
+ import torch.nn as nn
+ import torch
+
+
+ class Transcriber(nn.Module):
+     """
+     Implements a 2D convolutional U-Net architecture based loosely on SoundStream.
+     """
+
+     def __init__(self, sample_rate, n_octaves, bins_per_octave, secs_per_block=3, latent_size=None, model_complexity=1, skip_connections=False):
+         """
+         Initialize the full autoencoder.
+
+         Parameters
+         ----------
+         sample_rate : int
+             Expected sample rate of input
+         n_octaves : int
+             Number of octaves below Nyquist frequency to represent
+         bins_per_octave : int
+             Number of frequency bins within each octave
+         secs_per_block : float
+             Number of seconds to process at once with sliCQ
+         latent_size : int or None (Optional)
+             Dimensionality of latent space
+         model_complexity : int
+             Scaling factor for number of filters and embedding sizes
+         skip_connections : bool
+             Whether to include skip connections between encoder and decoder
+         """
+
+         nn.Module.__init__(self)
+
+         self.sliCQ = CQT(n_octaves=n_octaves,
+                          bins_per_octave=bins_per_octave,
+                          sample_rate=sample_rate,
+                          secs_per_block=secs_per_block)
+
+         self.encoder = Encoder(feature_size=self.sliCQ.n_bins, latent_size=latent_size, model_complexity=model_complexity)
+         self.decoder = Decoder(feature_size=self.sliCQ.n_bins, latent_size=latent_size, model_complexity=model_complexity)
+
+         if skip_connections:
+             # Start by adding encoder features with identity weighting
+             self.skip_weights = torch.nn.Parameter(torch.ones(5))
+         else:
+             # No skip connections
+             self.skip_weights = None
+
+     def encode(self, audio):
+         """
+         Encode a batch of raw audio into latent codes.
+
+         Parameters
+         ----------
+         audio : Tensor (B x 1 x T)
+             Batch of input raw audio
+
+         Returns
+         ----------
+         latents : Tensor (B x D_lat x T)
+             Batch of latent codes
+         embeddings : list of [Tensor (B x C x H x T)]
+             Embeddings produced by encoder at each level
+         losses : dict containing
+             ...
+         """
+
+         # Compute CQT spectral features
+         coefficients = self.sliCQ(audio)
+
+         # Encode features into latent vectors
+         latents, embeddings, losses = self.encoder(coefficients)
+
+         return latents, embeddings, losses
+
+     def apply_skip_connections(self, embeddings):
+         """
+         Apply skip connections to encoder embeddings, or discard the embeddings if skip connections do not exist.
+
+         Parameters
+         ----------
+         embeddings : list of [Tensor (B x C x H x T)]
+             Embeddings produced by encoder at each level
+
+         Returns
+         ----------
+         embeddings : list of [Tensor (B x C x H x T)]
+             Encoder embeddings scaled with learnable weight
+         """
+
+         if self.skip_weights is not None:
+             # Apply a learnable weight to the embeddings for the skip connection
+             embeddings = [self.skip_weights[i] * e for i, e in enumerate(embeddings)]
+         else:
+             # Discard embeddings from encoder
+             embeddings = None
+
+         return embeddings
+
+     def decode(self, latents, embeddings=None, transcribe=False):
+         """
+         Decode a batch of latent codes into logits representing real/imaginary coefficients.
+
+         Parameters
+         ----------
+         latents : Tensor (B x D_lat x T)
+             Batch of latent codes
+         embeddings : list of [Tensor (B x C x H x T)] or None (no skip connections)
+             Embeddings produced by encoder at each level
+         transcribe : bool
+             Switch for performing transcription vs. reconstruction
+
+         Returns
+         ----------
+         coefficients : Tensor (B x 2 x F x T)
+             Batch of output logits [-∞, ∞]
+         """
+
+         # Create binary values to indicate function decoder should perform
+         indicator = (not transcribe) * torch.ones_like(latents[..., :1, :])
+
+         # Concatenate indicator to final dimension of latents
+         latents = torch.cat((latents, indicator), dim=-2)
+
+         # Decode latent vectors into real/imaginary coefficients
+         coefficients = self.decoder(latents, embeddings)
+
+         return coefficients
+
+     def transcribe(self, audio):
+         """
+         Obtain transcriptions for a batch of raw audio.
+
+         Parameters
+         ----------
+         audio : Tensor (B x 1 x T)
+             Batch of input raw audio
+
+         Returns
+         ----------
+         activations : Tensor (B x F x T)
+             Batch of multi-pitch activations [0, 1]
+         """
+
+         # Encode raw audio into latent vectors
+         latents, embeddings, _ = self.encode(audio)
+
+         # Apply skip connections if they are turned on
+         embeddings = self.apply_skip_connections(embeddings)
+
+         # Estimate pitch using transcription switch
+         coefficients = self.decode(latents, embeddings, True)
+
+         # Extract magnitude of decoded coefficients and convert to activations
+         activations = torch.nn.functional.tanh(self.sliCQ.to_magnitude(coefficients))
+
+         return activations
+
+     def reconstruct(self, audio):
+         """
+         Obtain reconstructed coefficients for a batch of raw audio.
+
+         Parameters
+         ----------
+         audio : Tensor (B x 1 x T)
+             Batch of input raw audio
+
+         Returns
+         ----------
+         reconstruction : Tensor (B x 2 x F x T)
+             Batch of reconstructed spectral coefficients
+         """
+
+         # Encode raw audio into latent vectors
+         latents, embeddings, losses = self.encode(audio)
+
+         # Apply skip connections if they are turned on
+         embeddings = self.apply_skip_connections(embeddings)
+
+         # Decode latent vectors into spectral coefficients
+         reconstruction = self.decode(latents, embeddings)
+
+         return reconstruction
+
+     def forward(self, audio, consistency=False):
+         """
+         Perform all model functions efficiently (for training/evaluation).
+
+         Parameters
+         ----------
+         audio : Tensor (B x 1 x T)
+             Batch of input raw audio
+         consistency : bool
+             Whether to perform computations for consistency loss
+
+         Returns
+         ----------
+         reconstruction : Tensor (B x 2 x F x T)
+             Batch of reconstructed spectral coefficients
+         latents : Tensor (B x D_lat x T)
+             Batch of latent codes
+         transcription : Tensor (B x 2 x F x T)
+             Batch of transcription spectral coefficients
+         transcription_rec : Tensor (B x 2 x F x T)
+             Batch of reconstructed spectral coefficients for transcription coefficients input
+         transcription_scr : Tensor (B x 2 x F x T)
+             Batch of transcription spectral coefficients for transcription coefficients input
+         losses : dict containing
+             ...
+         """
+
+         # Encode raw audio into latent vectors
+         latents, embeddings, losses = self.encode(audio)
+
+         # Apply skip connections if they are turned on
+         embeddings = self.apply_skip_connections(embeddings)
+
+         # Decode latent vectors into spectral coefficients
+         reconstruction = self.decode(latents, embeddings)
+
+         # Estimate pitch using transcription switch
+         transcription = self.decode(latents, embeddings, True)
+
+         if consistency:
+             # Encode transcription coefficients for samples with ground-truth
+             latents_trn, embeddings_trn, _ = self.encoder(transcription)
+
+             # Apply skip connections if they are turned on
+             embeddings_trn = self.apply_skip_connections(embeddings_trn)
+
+             # Attempt to reconstruct transcription spectral coefficients
+             transcription_rec = self.decode(latents_trn, embeddings_trn)
+
+             # Attempt to transcribe audio pertaining to transcription coefficients
+             transcription_scr = self.decode(latents_trn, embeddings_trn, True)
+         else:
+             # Return null for both sets of coefficients
+             transcription_rec, transcription_scr = None, None
+
+         return reconstruction, latents, transcription, transcription_rec, transcription_scr, losses
+
+
+ class Encoder(nn.Module):
+     """
+     Implements a 2D convolutional encoder.
+     """
+
+     def __init__(self, feature_size, latent_size=None, model_complexity=1):
+         """
+         Initialize the encoder.
+
+         Parameters
+         ----------
+         feature_size : int
+             Dimensionality of input features
+         latent_size : int or None (Optional)
+             Dimensionality of latent space
+         model_complexity : int
+             Scaling factor for number of filters
+         """
+
+         nn.Module.__init__(self)
+
+         channels = (2 * 2 ** (model_complexity - 1),
+                     4 * 2 ** (model_complexity - 1),
+                     8 * 2 ** (model_complexity - 1),
+                     16 * 2 ** (model_complexity - 1),
+                     32 * 2 ** (model_complexity - 1))
+
+         # Make sure all channel sizes are integers
+         channels = tuple([round(c) for c in channels])
+
+         if latent_size is None:
+             # Set default dimensionality
+             latent_size = 32 * 2 ** (model_complexity - 1)
+
+         self.convin = nn.Sequential(
+             nn.Conv2d(2, channels[0], kernel_size=3, padding='same'),
+             nn.ELU(inplace=True)
+         )
+
+         self.block1 = EncoderBlock(channels[0], channels[1], stride=2)
+         self.block2 = EncoderBlock(channels[1], channels[2], stride=2)
+         self.block3 = EncoderBlock(channels[2], channels[3], stride=2)
+         self.block4 = EncoderBlock(channels[3], channels[4], stride=2)
+
+         embedding_size = feature_size
+
+         for i in range(4):
+             # Dimensionality after strided convolutions
+             embedding_size = embedding_size // 2 - 1
+
+         self.convlat = nn.Conv2d(channels[4], latent_size, kernel_size=(embedding_size, 1))
+
+     def forward(self, coefficients):
+         """
+         Encode a batch of input spectral features.
+
+         Parameters
+         ----------
+         coefficients : Tensor (B x 2 x F x T)
+             Batch of input spectral features
+
+         Returns
+         ----------
+         latents : Tensor (B x D_lat x T)
+             Batch of latent codes
+         embeddings : list of [Tensor (B x C x H x T)]
+             Embeddings produced by encoder at each level
+         losses : dict containing
+             ...
+         """
+
+         # Initialize a list to hold features for skip connections
+         embeddings = list()
+
+         # Encode features into embeddings
+         embeddings.append(self.convin(coefficients))
+         embeddings.append(self.block1(embeddings[-1]))
+         embeddings.append(self.block2(embeddings[-1]))
+         embeddings.append(self.block3(embeddings[-1]))
+         embeddings.append(self.block4(embeddings[-1]))
+
+         # Compute latent vectors from embeddings
+         latents = self.convlat(embeddings[-1]).squeeze(-2)
+
+         # No encoder losses
+         loss = dict()
+
+         return latents, embeddings, loss
+
+
+ class Decoder(nn.Module):
+     """
+     Implements a 2D convolutional decoder.
+     """
+
+     def __init__(self, feature_size, latent_size=None, model_complexity=1):
+         """
+         Initialize the decoder.
+
+         Parameters
+         ----------
+         feature_size : int
+             Dimensionality of input features
+         latent_size : int or None (Optional)
+             Dimensionality of latent space
+         model_complexity : int
+             Scaling factor for number of filters
+         """
+
+         nn.Module.__init__(self)
+
+         channels = (32 * 2 ** (model_complexity - 1),
+                     16 * 2 ** (model_complexity - 1),
+                     8 * 2 ** (model_complexity - 1),
+                     4 * 2 ** (model_complexity - 1),
+                     2 * 2 ** (model_complexity - 1))
+
+         # Make sure all channel sizes are integers
+         channels = tuple([round(c) for c in channels])
+
+         if latent_size is None:
+             # Set default dimensionality
+             latent_size = 32 * 2 ** (model_complexity - 1)
+
+         padding = list()
+
+         embedding_size = feature_size
+
+         for i in range(4):
+             # Padding required for expected output size
+             padding.append(embedding_size % 2)
+             # Dimensionality after strided convolutions
+             embedding_size = embedding_size // 2 - 1
+
+         # Reverse order
+         padding.reverse()
+
+         self.convin = nn.Sequential(
+             nn.ConvTranspose2d(latent_size + 1, channels[0], kernel_size=(embedding_size, 1)),
+             nn.ELU(inplace=True)
+         )
+
+         self.block1 = DecoderBlock(channels[0], channels[1], stride=2, padding=padding[0])
+         self.block2 = DecoderBlock(channels[1], channels[2], stride=2, padding=padding[1])
+         self.block3 = DecoderBlock(channels[2], channels[3], stride=2, padding=padding[2])
+         self.block4 = DecoderBlock(channels[3], channels[4], stride=2, padding=padding[3])
+
+         self.convout = nn.Conv2d(channels[4], 2, kernel_size=3, padding='same')
+
+     def forward(self, latents, encoder_embeddings=None):
+         """
+         Decode a batch of input latent codes.
+
+         Parameters
+         ----------
+         latents : Tensor (B x D_lat x T)
+             Batch of latent codes
+         encoder_embeddings : list of [Tensor (B x C x H x T)] or None (no skip connections)
+             Embeddings produced by encoder at each level
+
+         Returns
+         ----------
+         output : Tensor (B x 2 x F x T)
+             Batch of output logits [-∞, ∞]
+         """
+
+         # Restore feature dimension
+         latents = latents.unsqueeze(-2)
+
+         # Process latents with decoder blocks
+         embeddings = self.convin(latents)
+
+         if encoder_embeddings is not None:
+             embeddings = embeddings + encoder_embeddings[-1]
+
+         embeddings = self.block1(embeddings)
+
+         if encoder_embeddings is not None:
+             embeddings = embeddings + encoder_embeddings[-2]
+
+         embeddings = self.block2(embeddings)
+
+         if encoder_embeddings is not None:
+             embeddings = embeddings + encoder_embeddings[-3]
+
+         embeddings = self.block3(embeddings)
+
+         if encoder_embeddings is not None:
+             embeddings = embeddings + encoder_embeddings[-4]
+
+         embeddings = self.block4(embeddings)
+
+         if encoder_embeddings is not None:
+             embeddings = embeddings + encoder_embeddings[-5]
+
+         # Decode embeddings into spectral logits
+         output = self.convout(embeddings)
+
+         return output
+
+
+ class EncoderBlock(nn.Module):
+     """
+     Implements a chain of residual convolutional blocks with progressively
+     increased dilation, followed by down-sampling via strided convolution.
+     """
+
+     def __init__(self, in_channels, out_channels, stride=2):
+         """
+         Initialize the encoder block.
+
+         Parameters
+         ----------
+         in_channels : int
+             Number of input feature channels
+         out_channels : int
+             Number of output feature channels
+         stride : int
+             Stride for the final convolutional layer
+         """
+
+         nn.Module.__init__(self)
+
+         self.block1 = ResidualConv2dBlock(in_channels, in_channels, kernel_size=3, dilation=1)
+         self.block2 = ResidualConv2dBlock(in_channels, in_channels, kernel_size=3, dilation=2)
+         self.block3 = ResidualConv2dBlock(in_channels, in_channels, kernel_size=3, dilation=3)
+
+         self.hop = stride
+         self.win = 2 * stride
+
+         self.sconv = nn.Sequential(
+             # Down-sample along frequency (height) dimension via strided convolution
+             nn.Conv2d(in_channels, out_channels, kernel_size=(self.win, 1), stride=(self.hop, 1)),
+             nn.ELU(inplace=True)
+         )
+
+     def forward(self, x):
+         """
+         Feed features through the encoder block.
+
+         Parameters
+         ----------
+         x : Tensor (B x C_in x H x W)
+             Batch of input features
+
+         Returns
+         ----------
+         y : Tensor (B x C_out x H x W)
+             Batch of corresponding output features
+         """
+
+         # Process features
+         y = self.block1(x)
+         y = self.block2(y)
+         y = self.block3(y)
+
+         # Down-sample
+         y = self.sconv(y)
+
+         return y
+
+
+ class DecoderBlock(nn.Module):
+     """
+     Implements up-sampling via transposed convolution, followed by a chain
+     of residual convolutional blocks with progressively increased dilation.
+     """
+
+     def __init__(self, in_channels, out_channels, stride=2, padding=0):
+         """
+         Initialize the decoder block.
+
+         Parameters
+         ----------
+         in_channels : int
+             Number of input feature channels
+         out_channels : int
+             Number of output feature channels
+         stride : int
+             Stride for the transposed convolution
+         padding : int
+             Number of features to pad after up-sampling
+         """
+
+         nn.Module.__init__(self)
+
+         self.hop = stride
+         self.win = 2 * stride
+
+         self.tconv = nn.Sequential(
+             # Up-sample along frequency (height) dimension via transposed convolution
+             nn.ConvTranspose2d(in_channels, out_channels, kernel_size=(self.win, 1), stride=(self.hop, 1), output_padding=(padding, 0)),
+             nn.ELU(inplace=True)
+         )
+
+         self.block1 = ResidualConv2dBlock(out_channels, out_channels, kernel_size=3, dilation=1)
+         self.block2 = ResidualConv2dBlock(out_channels, out_channels, kernel_size=3, dilation=2)
+         self.block3 = ResidualConv2dBlock(out_channels, out_channels, kernel_size=3, dilation=3)
+
+     def forward(self, x):
+         """
+         Feed features through the decoder block.
+
+         Parameters
+         ----------
+         x : Tensor (B x C_in x H x W)
+             Batch of input features
+
+         Returns
+         ----------
+         y : Tensor (B x C_out x H x W)
+             Batch of corresponding output features
+         """
+
+         # Up-sample
+         y = self.tconv(x)
+
+         # Process features
+         y = self.block1(y)
+         y = self.block2(y)
+         y = self.block3(y)
+
+         return y
+
+
+ class ResidualConv2dBlock(nn.Module):
+     """
+     Implements a 2D convolutional block with dilation, no down-sampling, and a residual connection.
+     """
+
+     def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1):
+         """
+         Initialize the convolutional block.
+
+         Parameters
+         ----------
+         in_channels : int
+             Number of input feature channels
+         out_channels : int
+             Number of output feature channels
+         kernel_size : int
+             Kernel size for convolutions
+         dilation : int
+             Amount of dilation for first convolution
+         """
+
+         nn.Module.__init__(self)
+
+         self.conv1 = nn.Sequential(
+             # TODO - only dilate across frequency?
+             nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding='same', dilation=dilation),
+             nn.ELU(inplace=True)
+         )
+
+         self.conv2 = nn.Sequential(
+             nn.Conv2d(out_channels, out_channels, kernel_size=1),
+             nn.ELU(inplace=True)
+         )
+
+     def forward(self, x):
+         """
+         Feed features through the convolutional block.
+
+         Parameters
+         ----------
+         x : Tensor (B x C_in x H x W)
+             Batch of input features
+
+         Returns
+         ----------
+         y : Tensor (B x C_out x H x W)
+             Batch of corresponding output features
+         """
+
+         # Process features
+         y = self.conv1(x)
+         y = self.conv2(y)
+
+         # Residual connection
+         y = y + x
+
+         return y
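A short sketch of standalone use of the Transcriber; the hyperparameters below are illustrative assumptions and are not necessarily those of model-8750.pt, which app.py loads directly as a pickled module:

import torch
from models.transcriber import Transcriber

# Illustrative hyperparameters (assumptions, not the released configuration)
model = Transcriber(sample_rate=22050, n_octaves=9, bins_per_octave=60,
                    model_complexity=2, skip_connections=True)
model.eval()

audio = torch.randn(1, 1, 2 * 22050)               # (B x 1 x T) batch of raw audio
audio = model.sliCQ.pad_to_block_length(audio)

with torch.no_grad():
    activations = model.transcribe(audio)          # (B x F x T) multi-pitch activations in [0, 1]
    reconstruction = model.reconstruct(audio)      # (B x 2 x F x T) spectral coefficients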
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ git+https://github.com/audacitorch/pyharp.git#egg=pyharp
+ #git+https://github.com/sony/timbre-trap@main
+ torchaudio
+ torch
+ cqt_pytorch
+ librosa