bandhav committed on
Commit
e6a6383
1 Parent(s): 0e59911
app.py ADDED
@@ -0,0 +1,62 @@
+ import argparse
+ import os
+
+ import wget
+ import torch
+ import torchaudio
+ import gradio as gr
+
+ from src.helpers import utils
+ from src.training.dcc_tf import Net as Waveformer
+
+ TARGETS = [
+     "Acoustic_guitar", "Applause", "Bark", "Bass_drum",
+     "Burping_or_eructation", "Bus", "Cello", "Chime", "Clarinet",
+     "Computer_keyboard", "Cough", "Cowbell", "Double_bass",
+     "Drawer_open_or_close", "Electric_piano", "Fart", "Finger_snapping",
+     "Fireworks", "Flute", "Glockenspiel", "Gong", "Gunshot_or_gunfire",
+     "Harmonica", "Hi-hat", "Keys_jangling", "Knock", "Laughter", "Meow",
+     "Microwave_oven", "Oboe", "Saxophone", "Scissors", "Shatter",
+     "Snare_drum", "Squeak", "Tambourine", "Tearing", "Telephone",
+     "Trumpet", "Violin_or_fiddle", "Writing"
+ ]
+
+ if not os.path.exists('default_config.json'):
+     config_url = 'https://targetsound.cs.washington.edu/files/default_config.json'
+     print("Downloading model configuration from %s:" % config_url)
+     wget.download(config_url)
+
+ if not os.path.exists('default_ckpt.pt'):
+     ckpt_url = 'https://targetsound.cs.washington.edu/files/default_ckpt.pt'
+     print("\nDownloading the checkpoint from %s:" % ckpt_url)
+     wget.download(ckpt_url)
+
+ # Instantiate model
+ params = utils.Params('default_config.json')
+ model = Waveformer(**params.model_params)
+ utils.load_checkpoint('default_ckpt.pt', model)
+ model.eval()
+
+ def waveformer(audio, label_choices):
+     # Read input audio
+     fs, mixture = audio
+     if fs != 44100:
+         raise ValueError("Expected a 44100 Hz sampling rate, got %d Hz" % fs)
+     mixture = torch.from_numpy(mixture).unsqueeze(0)
+
+     # Construct the query vector
+     if len(label_choices) == 0:
+         raise ValueError("At least one target label must be selected")
+     query = torch.zeros(1, len(TARGETS))
+     for t in label_choices:
+         query[0, TARGETS.index(t)] = 1.
+
+     with torch.no_grad():
+         output = model(mixture, query)
+
+     return fs, output.squeeze(0).numpy()
+
+
+ label_checkbox = gr.CheckboxGroup(choices=TARGETS)
+ demo = gr.Interface(fn=waveformer, inputs=['audio', label_checkbox], outputs="audio")
+ demo.launch()
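For reference, a minimal sketch of exercising the same `waveformer()` inference path outside the Gradio UI. The 44.1 kHz input file `mixture.wav` is a hypothetical local path; `soundfile` (already listed in requirements.txt) is assumed for I/O, and float32 samples are assumed acceptable even though the Gradio widget hands over raw integer samples.

# Sketch only: call waveformer() directly, bypassing gr.Interface.
# "mixture.wav" is a hypothetical 44.1 kHz mono recording.
import soundfile as sf

audio, fs = sf.read("mixture.wav", dtype="float32")        # audio: [T] numpy array
fs_out, separated = waveformer((fs, audio), ["Knock", "Laughter"])
sf.write("separated.wav", separated.squeeze(), fs_out)     # write the extracted target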
default_config.json ADDED
@@ -0,0 +1,60 @@
+ {
+     "model": "src.training.dcc_tf",
+     "model_params":
+     {
+         "label_len": 41,
+         "L": 32,
+         "enc_dim": 512,
+         "num_enc_layers": 10,
+         "dec_dim": 256,
+         "num_dec_layers": 1,
+         "dec_buf_len": 13,
+         "dec_chunk_size": 13,
+         "out_buf_len": 4,
+         "use_pos_enc": "true"
+     },
+     "train_data":
+     {
+         "input_dir": "data/FSDSoundScapes",
+         "dset": "train",
+         "sr": 44100,
+         "resample_rate": null,
+         "max_num_targets": 3
+     },
+     "val_data":
+     {
+         "input_dir": "data/FSDSoundScapes",
+         "dset": "val",
+         "sr": 44100,
+         "resample_rate": null,
+         "max_num_targets": 3
+     },
+     "test_data":
+     {
+         "input_dir": "data/FSDSoundScapes",
+         "dset": "test",
+         "sr": 44100,
+         "resample_rate": null,
+         "max_num_targets": 3
+     },
+     "optim":
+     {
+         "lr": 5e-4,
+         "weight_decay": 0.0
+     },
+     "lr_sched":
+     {
+         "mode": "max",
+         "factor": 0.1,
+         "patience": 5,
+         "min_lr": 5e-6,
+         "threshold": 0.1,
+         "threshold_mode": "abs"
+     },
+     "base_metric": "scale_invariant_signal_noise_ratio",
+     "fix_lr_epochs": 50,
+     "epochs": 150,
+     "batch_size": 16,
+     "eval_batch_size": 64,
+     "n_workers": 16
+ }
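This same schema is what src/training/train.py later expects at `<exp_dir>/config.json`. A hedged sketch of cloning and tweaking the file for a new run (the `experiments/my_run` directory and the tweaked parameter are hypothetical examples):

# Sketch only: copy default_config.json into a new experiment directory,
# which train.py reads back as <exp_dir>/config.json.
import os
from src.helpers.utils import Params

params = Params("default_config.json")
params.model_params["num_dec_layers"] = 2        # example tweak of a model parameter
os.makedirs("experiments/my_run", exist_ok=True)
params.save("experiments/my_run/config.json")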
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ ### Requirements
+ librosa
+ torch
+ torchaudio
+ soundfile
+ numpy
+ speechbrain
+ wget
+
src/__init__.py ADDED
File without changes
src/helpers/__init__.py ADDED
File without changes
src/helpers/utils.py ADDED
@@ -0,0 +1,205 @@
+ """A collection of useful helper functions"""
+
+ import os
+ import logging
+ import json
+
+ import torch
+ from torch.profiler import profile, record_function, ProfilerActivity
+ import pandas as pd
+ from torchmetrics.functional import (
+     scale_invariant_signal_noise_ratio as si_snr,
+     signal_noise_ratio as snr,
+     signal_distortion_ratio as sdr,
+     scale_invariant_signal_distortion_ratio as si_sdr)
+ import matplotlib.pyplot as plt
+
+ class Params():
+     """Class that loads hyperparameters from a json file.
+     Example:
+     ```
+     params = Params(json_path)
+     print(params.learning_rate)
+     params.learning_rate = 0.5 # change the value of learning_rate in params
+     ```
+     """
+
+     def __init__(self, json_path):
+         with open(json_path) as f:
+             params = json.load(f)
+             self.__dict__.update(params)
+
+     def save(self, json_path):
+         with open(json_path, 'w') as f:
+             json.dump(self.__dict__, f, indent=4)
+
+     def update(self, json_path):
+         """Loads parameters from json file"""
+         with open(json_path) as f:
+             params = json.load(f)
+             self.__dict__.update(params)
+
+     @property
+     def dict(self):
+         """Gives dict-like access to Params instance by `params.dict['learning_rate']`"""
+         return self.__dict__
+
+ def save_graph(train_metrics, test_metrics, save_dir):
+     metrics = [snr, si_snr]
+     results = {'train_loss': train_metrics['loss'],
+                'test_loss' : test_metrics['loss']}
+
+     for m_fn in metrics:
+         results["train_"+m_fn.__name__] = train_metrics[m_fn.__name__]
+         results["test_"+m_fn.__name__] = test_metrics[m_fn.__name__]
+
+     results_pd = pd.DataFrame(results)
+
+     results_pd.to_csv(os.path.join(save_dir, 'results.csv'))
+
+     fig, temp_ax = plt.subplots(2, 3, figsize=(15,10))
+     axs = []
+     for i in temp_ax:
+         for j in i:
+             axs.append(j)
+
+     x = range(len(train_metrics['loss']))
+     axs[0].plot(x, train_metrics['loss'], label='train')
+     axs[0].plot(x, test_metrics['loss'], label='test')
+     axs[0].set(ylabel='Loss')
+     axs[0].set(xlabel='Epoch')
+     axs[0].set_title('loss', fontweight='bold')
+     axs[0].legend()
+
+     for i in range(len(metrics)):
+         axs[i+1].plot(x, train_metrics[metrics[i].__name__], label='train')
+         axs[i+1].plot(x, test_metrics[metrics[i].__name__], label='test')
+         axs[i+1].set(xlabel='Epoch')
+         axs[i+1].set_title(metrics[i].__name__, fontweight='bold')
+         axs[i+1].legend()
+
+     plt.tight_layout()
+     plt.savefig(os.path.join(save_dir, 'results.png'))
+     plt.close(fig)
+
+ def set_logger(log_path):
+     """Set the logger to log info in terminal and file `log_path`.
+     In general, it is useful to have a logger so that every output to the terminal is saved
+     in a permanent file. Here we save it to `model_dir/train.log`.
+     Example:
+     ```
+     logging.info("Starting training...")
+     ```
+     Args:
+         log_path: (string) where to log
+     """
+     logger = logging.getLogger()
+     logger.setLevel(logging.INFO)
+     logger.handlers.clear()
+
+     # Logging to a file
+     file_handler = logging.FileHandler(log_path)
+     file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
+     logger.addHandler(file_handler)
+
+     # Logging to console
+     stream_handler = logging.StreamHandler()
+     stream_handler.setFormatter(logging.Formatter('%(message)s'))
+     logger.addHandler(stream_handler)
+
+ def load_checkpoint(checkpoint, model, optim=None, lr_sched=None, data_parallel=False):
+     """Loads model parameters (state_dict) from file_path.
+
+     Args:
+         checkpoint: (string) filename which needs to be loaded
+         model: (torch.nn.Module) model for which the parameters are loaded
+         data_parallel: (bool) if the model is a data parallel model
+     """
+     if not os.path.exists(checkpoint):
+         raise FileNotFoundError("File doesn't exist {}".format(checkpoint))
+
+     state_dict = torch.load(checkpoint)
+
+     if data_parallel:
+         state_dict['model_state_dict'] = {
+             'module.' + k: state_dict['model_state_dict'][k]
+             for k in state_dict['model_state_dict'].keys()}
+     model.load_state_dict(state_dict['model_state_dict'])
+
+     if optim is not None:
+         optim.load_state_dict(state_dict['optim_state_dict'])
+
+     if lr_sched is not None:
+         lr_sched.load_state_dict(state_dict['lr_sched_state_dict'])
+
+     return state_dict['epoch'], state_dict['train_metrics'], \
+         state_dict['val_metrics']
+
+ def save_checkpoint(checkpoint, epoch, model, optim=None, lr_sched=None,
+                     train_metrics=None, val_metrics=None, data_parallel=False):
+     """Saves model parameters (state_dict) to file_path.
+
+     Args:
+         checkpoint: (string) filename to which the checkpoint is saved
+         model: (torch.nn.Module) model whose parameters are saved
+         data_parallel: (bool) if the model is a data parallel model
+     """
+     if os.path.exists(checkpoint):
+         raise FileExistsError("File already exists {}".format(checkpoint))
+
+     model_state_dict = model.state_dict()
+     if data_parallel:
+         model_state_dict = {
+             k.partition('module.')[2]:
+             model_state_dict[k] for k in model_state_dict.keys()}
+
+     optim_state_dict = None if not optim else optim.state_dict()
+     lr_sched_state_dict = None if not lr_sched else lr_sched.state_dict()
+
+     state_dict = {
+         'epoch': epoch,
+         'model_state_dict': model_state_dict,
+         'optim_state_dict': optim_state_dict,
+         'lr_sched_state_dict': lr_sched_state_dict,
+         'train_metrics': train_metrics,
+         'val_metrics': val_metrics
+     }
+
+     torch.save(state_dict, checkpoint)
+
+ def model_size(model):
+     """
+     Returns size of the `model` in millions of parameters.
+     """
+     num_train_params = sum(
+         p.numel() for p in model.parameters() if p.requires_grad)
+     return num_train_params / 1e6
+
+ def run_time(model, inputs, profiling=False):
+     """
+     Returns runtime of a model in ms.
+     """
+     # Warmup
+     for _ in range(100):
+         output = model(*inputs)
+
+     with profile(activities=[ProfilerActivity.CPU],
+                  record_shapes=True) as prof:
+         with record_function("model_inference"):
+             output = model(*inputs)
+
+     # Print profiling results
+     if profiling:
+         print(prof.key_averages().table(sort_by="self_cpu_time_total",
+                                         row_limit=20))
+
+     # Return runtime in ms
+     return prof.profiler.self_cpu_time_total / 1000
+
+ def format_lr_info(optimizer):
+     lr_info = ""
+     for i, pg in enumerate(optimizer.param_groups):
+         lr_info += " {group %d: params=%.5fM lr=%.1E}" % (
+             i, sum([p.numel() for p in pg['params']]) / (1024 ** 2), pg['lr'])
+     return lr_info
+
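A hedged sketch of the save/load round trip defined above, using a throwaway model; the `toy_ckpt.pt` filename is hypothetical, and note that save_checkpoint refuses to overwrite an existing file:

# Sketch only: checkpoint round trip with a toy model.
import torch.nn as nn
from src.helpers import utils

toy = nn.Linear(4, 4)
utils.save_checkpoint("toy_ckpt.pt", epoch=0, model=toy,
                      train_metrics={"loss": [1.0]}, val_metrics={"loss": [1.2]})
epoch, train_metrics, val_metrics = utils.load_checkpoint("toy_ckpt.pt", nn.Linear(4, 4))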
src/training/__init__.py ADDED
File without changes
src/training/dcc_tf.py ADDED
@@ -0,0 +1,486 @@
1
+ import math
2
+ from collections import OrderedDict
3
+ from typing import Optional
4
+
5
+ from torch import Tensor
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torch.optim as optim
10
+ from torchmetrics.functional import(
11
+ scale_invariant_signal_noise_ratio as si_snr,
12
+ signal_noise_ratio as snr,
13
+ signal_distortion_ratio as sdr,
14
+ scale_invariant_signal_distortion_ratio as si_sdr)
15
+
16
+ from speechbrain.lobes.models.transformer.Transformer import PositionalEncoding
17
+
18
+ def mod_pad(x, chunk_size, pad):
19
+ # Mod pad the input to perform integer number of
20
+ # inferences
21
+ mod = 0
22
+ if (x.shape[-1] % chunk_size) != 0:
23
+ mod = chunk_size - (x.shape[-1] % chunk_size)
24
+
25
+ x = F.pad(x, (0, mod))
26
+ x = F.pad(x, pad)
27
+
28
+ return x, mod
29
+
30
+ class LayerNormPermuted(nn.LayerNorm):
31
+ def __init__(self, *args, **kwargs):
32
+ super(LayerNormPermuted, self).__init__(*args, **kwargs)
33
+
34
+ def forward(self, x):
35
+ """
36
+ Args:
37
+ x: [B, C, T]
38
+ """
39
+ x = x.permute(0, 2, 1) # [B, T, C]
40
+ x = super().forward(x)
41
+ x = x.permute(0, 2, 1) # [B, C, T]
42
+ return x
43
+
44
+ class DepthwiseSeparableConv(nn.Module):
45
+ """
46
+ Depthwise separable convolutions
47
+ """
48
+ def __init__(self, in_channels, out_channels, kernel_size, stride,
49
+ padding, dilation):
50
+ super(DepthwiseSeparableConv, self).__init__()
51
+
52
+ self.layers = nn.Sequential(
53
+ nn.Conv1d(in_channels, in_channels, kernel_size, stride,
54
+ padding, groups=in_channels, dilation=dilation),
55
+ LayerNormPermuted(in_channels),
56
+ nn.ReLU(),
57
+ nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=1,
58
+ padding=0),
59
+ LayerNormPermuted(out_channels),
60
+ nn.ReLU(),
61
+ )
62
+
63
+ def forward(self, x):
64
+ return self.layers(x)
65
+
66
+ class DilatedCausalConvEncoder(nn.Module):
67
+ """
68
+ A dilated causal convolution based encoder for encoding
69
+ time domain audio input into latent space.
70
+ """
71
+ def __init__(self, channels, num_layers, kernel_size=3):
72
+ super(DilatedCausalConvEncoder, self).__init__()
73
+ self.channels = channels
74
+ self.num_layers = num_layers
75
+ self.kernel_size = kernel_size
76
+
77
+ # Compute buffer lengths for each layer
78
+ # buf_length[i] = (kernel_size - 1) * dilation[i]
79
+ self.buf_lengths = [(kernel_size - 1) * 2**i
80
+ for i in range(num_layers)]
81
+
82
+ # Compute buffer start indices for each layer
83
+ self.buf_indices = [0]
84
+ for i in range(num_layers - 1):
85
+ self.buf_indices.append(
86
+ self.buf_indices[-1] + self.buf_lengths[i])
87
+
88
+ # Dilated causal conv layers aggregate previous context to obtain
89
+ # contextual encoded input.
90
+ _dcc_layers = OrderedDict()
91
+ for i in range(num_layers):
92
+ dcc_layer = DepthwiseSeparableConv(
93
+ channels, channels, kernel_size=3, stride=1,
94
+ padding=0, dilation=2**i)
95
+ _dcc_layers.update({'dcc_%d' % i: dcc_layer})
96
+ self.dcc_layers = nn.Sequential(_dcc_layers)
97
+
98
+ def init_ctx_buf(self, batch_size, device):
99
+ """
100
+ Returns an initialized context buffer for a given batch size.
101
+ """
102
+ return torch.zeros(
103
+ (batch_size, self.channels,
104
+ (self.kernel_size - 1) * (2**self.num_layers - 1)),
105
+ device=device)
106
+
107
+ def forward(self, x, ctx_buf):
108
+ """
109
+ Encodes input audio `x` into latent space, and aggregates
110
+ contextual information in `ctx_buf`. Also generates new context
111
+ buffer with updated context.
112
+ Args:
113
+ x: [B, in_channels, T]
114
+ Input multi-channel audio.
115
+ ctx_buf: {[B, channels, self.buf_length[0]], ...}
116
+ A list of tensors holding context for each dilation
117
+ causal conv layer. (len(ctx_buf) == self.num_layers)
118
+ Returns:
119
+ ctx_buf: {[B, channels, self.buf_length[0]], ...}
120
+ Updated context buffer with output as the
121
+ last element.
122
+ """
123
+ T = x.shape[-1] # Sequence length
124
+
125
+ for i in range(self.num_layers):
126
+ buf_start_idx = self.buf_indices[i]
127
+ buf_end_idx = self.buf_indices[i] + self.buf_lengths[i]
128
+
129
+ # DCC input: concatenation of current output and context
130
+ dcc_in = torch.cat(
131
+ (ctx_buf[..., buf_start_idx:buf_end_idx], x), dim=-1)
132
+
133
+ # Push current output to the context buffer
134
+ ctx_buf[..., buf_start_idx:buf_end_idx] = \
135
+ dcc_in[..., -self.buf_lengths[i]:]
136
+
137
+ # Residual connection
138
+ x = x + self.dcc_layers[i](dcc_in)
139
+
140
+ return x, ctx_buf
141
+
142
+ class CausalTransformerDecoderLayer(torch.nn.TransformerDecoderLayer):
143
+ """
144
+ Adapted from:
145
+ "https://github.com/alexmt-scale/causal-transformer-decoder/blob/"
146
+ "0caf6ad71c46488f76d89845b0123d2550ef792f/"
147
+ "causal_transformer_decoder/model.py#L77"
148
+ """
149
+ def forward(
150
+ self,
151
+ tgt: Tensor,
152
+ memory: Optional[Tensor] = None,
153
+ chunk_size: int = 1
154
+ ) -> Tensor:
155
+ tgt_last_tok = tgt[:, -chunk_size:, :]
156
+
157
+ # self attention part
158
+ tmp_tgt, sa_map = self.self_attn(
159
+ tgt_last_tok,
160
+ tgt,
161
+ tgt,
162
+ attn_mask=None, # not needed because we only care about the last token
163
+ key_padding_mask=None,
164
+ )
165
+ tgt_last_tok = tgt_last_tok + self.dropout1(tmp_tgt)
166
+ tgt_last_tok = self.norm1(tgt_last_tok)
167
+
168
+ # encoder-decoder attention
169
+ if memory is not None:
170
+ tmp_tgt, ca_map = self.multihead_attn(
171
+ tgt_last_tok,
172
+ memory,
173
+ memory,
174
+ attn_mask=None, # Attend to the entire chunk
175
+ key_padding_mask=None,
176
+ )
177
+ tgt_last_tok = tgt_last_tok + self.dropout2(tmp_tgt)
178
+ tgt_last_tok = self.norm2(tgt_last_tok)
179
+
180
+ # final feed-forward network
181
+ tmp_tgt = self.linear2(
182
+ self.dropout(self.activation(self.linear1(tgt_last_tok)))
183
+ )
184
+ tgt_last_tok = tgt_last_tok + self.dropout3(tmp_tgt)
185
+ tgt_last_tok = self.norm3(tgt_last_tok)
186
+ return tgt_last_tok, sa_map, ca_map
187
+
188
+ class CausalTransformerDecoder(nn.Module):
189
+ """
190
+ A causal transformer decoder which decodes input vectors using
191
+ precisely `ctx_len` past vectors in the sequence, and using no future
192
+ vectors at all.
193
+ """
194
+ def __init__(self, model_dim, ctx_len, chunk_size, num_layers,
195
+ nhead, use_pos_enc, ff_dim):
196
+ super(CausalTransformerDecoder, self).__init__()
197
+ self.num_layers = num_layers
198
+ self.model_dim = model_dim
199
+ self.ctx_len = ctx_len
200
+ self.chunk_size = chunk_size
201
+ self.nhead = nhead
202
+ self.use_pos_enc = use_pos_enc
203
+ self.unfold = nn.Unfold(kernel_size=(ctx_len + chunk_size, 1), stride=chunk_size)
204
+ self.pos_enc = PositionalEncoding(model_dim, max_len=200)
205
+ self.tf_dec_layers = nn.ModuleList([CausalTransformerDecoderLayer(
206
+ d_model=model_dim, nhead=nhead, dim_feedforward=ff_dim,
207
+ batch_first=True) for _ in range(num_layers)])
208
+
209
+ def init_ctx_buf(self, batch_size, device):
210
+ return torch.zeros(
211
+ (batch_size, self.num_layers + 1, self.ctx_len, self.model_dim),
212
+ device=device)
213
+
214
+ def _causal_unfold(self, x):
215
+ """
216
+ Unfolds the sequence into a batch of sequences
217
+ prepended with `ctx_len` previous values.
218
+
219
+ Args:
220
+ x: [B, ctx_len + L, C]
221
+ ctx_len: int
222
+ Returns:
223
+ [B * L, ctx_len + 1, C]
224
+ """
225
+ B, T, C = x.shape
226
+ x = x.permute(0, 2, 1) # [B, C, ctx_len + L]
227
+ x = self.unfold(x.unsqueeze(-1)) # [B, C * (ctx_len + chunk_size), -1]
228
+ x = x.permute(0, 2, 1)
229
+ x = x.reshape(B, -1, C, self.ctx_len + self.chunk_size)
230
+ x = x.reshape(-1, C, self.ctx_len + self.chunk_size)
231
+ x = x.permute(0, 2, 1)
232
+ return x
233
+
234
+ def forward(self, tgt, mem, ctx_buf, probe=False):
235
+ """
236
+ Args:
237
+ x: [B, model_dim, T]
238
+ ctx_buf: [B, num_layers, model_dim, ctx_len]
239
+ """
240
+ mem, _ = mod_pad(mem, self.chunk_size, (0, 0))
241
+ tgt, mod = mod_pad(tgt, self.chunk_size, (0, 0))
242
+
243
+ # Input sequence length
244
+ B, C, T = tgt.shape
245
+
246
+ tgt = tgt.permute(0, 2, 1)
247
+ mem = mem.permute(0, 2, 1)
248
+
249
+ # Prepend mem with the context
250
+ mem = torch.cat((ctx_buf[:, 0, :, :], mem), dim=1)
251
+ ctx_buf[:, 0, :, :] = mem[:, -self.ctx_len:, :]
252
+ mem_ctx = self._causal_unfold(mem)
253
+ if self.use_pos_enc:
254
+ mem_ctx = mem_ctx + self.pos_enc(mem_ctx)
255
+
256
+ # Attention chunk size: required to ensure the model
257
+ # wouldn't trigger an out-of-memory error when working
258
+ # on long sequences.
259
+ K = 1000
260
+
261
+ for i, tf_dec_layer in enumerate(self.tf_dec_layers):
262
+ # Update the tgt with context
263
+ tgt = torch.cat((ctx_buf[:, i + 1, :, :], tgt), dim=1)
264
+ ctx_buf[:, i + 1, :, :] = tgt[:, -self.ctx_len:, :]
265
+
266
+ # Compute encoded output
267
+ tgt_ctx = self._causal_unfold(tgt)
268
+ if self.use_pos_enc and i == 0:
269
+ tgt_ctx = tgt_ctx + self.pos_enc(tgt_ctx)
270
+ tgt = torch.zeros_like(tgt_ctx)[:, -self.chunk_size:, :]
271
+ for i in range(int(math.ceil(tgt.shape[0] / K))):
272
+ tgt[i*K:(i+1)*K], _sa_map, _ca_map = tf_dec_layer(
273
+ tgt_ctx[i*K:(i+1)*K], mem_ctx[i*K:(i+1)*K],
274
+ self.chunk_size)
275
+ tgt = tgt.reshape(B, T, C)
276
+
277
+ tgt = tgt.permute(0, 2, 1)
278
+ if mod != 0:
279
+ tgt = tgt[..., :-mod]
280
+
281
+ return tgt, ctx_buf
282
+
283
+ class MaskNet(nn.Module):
284
+ def __init__(self, enc_dim, num_enc_layers, dec_dim, dec_buf_len,
285
+ dec_chunk_size, num_dec_layers, use_pos_enc, skip_connection, proj):
286
+ super(MaskNet, self).__init__()
287
+ self.skip_connection = skip_connection
288
+ self.proj = proj
289
+
290
+ # Encoder based on dilated causal convolutions.
291
+ self.encoder = DilatedCausalConvEncoder(channels=enc_dim,
292
+ num_layers=num_enc_layers)
293
+
294
+ # Project between encoder and decoder dimensions
295
+ self.proj_e2d_e = nn.Sequential(
296
+ nn.Conv1d(enc_dim, dec_dim, kernel_size=1, stride=1, padding=0,
297
+ groups=dec_dim),
298
+ nn.ReLU())
299
+ self.proj_e2d_l = nn.Sequential(
300
+ nn.Conv1d(enc_dim, dec_dim, kernel_size=1, stride=1, padding=0,
301
+ groups=dec_dim),
302
+ nn.ReLU())
303
+ self.proj_d2e = nn.Sequential(
304
+ nn.Conv1d(dec_dim, enc_dim, kernel_size=1, stride=1, padding=0,
305
+ groups=dec_dim),
306
+ nn.ReLU())
307
+
308
+ # Transformer decoder that operates on chunks of size
309
+ # buffer size.
310
+ self.decoder = CausalTransformerDecoder(
311
+ model_dim=dec_dim, ctx_len=dec_buf_len, chunk_size=dec_chunk_size,
312
+ num_layers=num_dec_layers, nhead=8, use_pos_enc=use_pos_enc,
313
+ ff_dim=2 * dec_dim)
314
+
315
+ def forward(self, x, l, enc_buf, dec_buf):
316
+ """
317
+ Generates a mask based on encoded input `e` and the one-hot
318
+ label `label`.
319
+
320
+ Args:
321
+ x: [B, C, T]
322
+ Input audio sequence
323
+ l: [B, C]
324
+ Label embedding
325
+ ctx_buf: {[B, C, <receptive field of the layer>], ...}
326
+ List of context buffers maintained by DCC encoder
327
+ """
328
+ # Encode the label-integrated input
329
+ e, enc_buf = self.encoder(x, enc_buf)
330
+
331
+ # Label integration
332
+ l = l.unsqueeze(2) * e
333
+
334
+ # Project to `dec_dim` dimensions
335
+ if self.proj:
336
+ e = self.proj_e2d_e(e)
337
+ m = self.proj_e2d_l(l)
338
+ # Cross-attention to predict the mask
339
+ m, dec_buf = self.decoder(m, e, dec_buf)
340
+ else:
341
+ # Cross-attention to predict the mask
342
+ m, dec_buf = self.decoder(l, e, dec_buf)
343
+
344
+ # Project mask to encoder dimensions
345
+ if self.proj:
346
+ m = self.proj_d2e(m)
347
+
348
+ # Final mask after residual connection
349
+ if self.skip_connection:
350
+ m = l + m
351
+
352
+ return m, enc_buf, dec_buf
353
+
354
+ class Net(nn.Module):
355
+ def __init__(self, label_len, L=8,
356
+ enc_dim=512, num_enc_layers=10,
357
+ dec_dim=256, dec_buf_len=100, num_dec_layers=2,
358
+ dec_chunk_size=72, out_buf_len=2,
359
+ use_pos_enc=True, skip_connection=True, proj=True, lookahead=True):
360
+ super(Net, self).__init__()
361
+ self.L = L
362
+ self.out_buf_len = out_buf_len
363
+ self.enc_dim = enc_dim
364
+ self.lookahead = lookahead
365
+
366
+ # Input conv to convert input audio to a latent representation
367
+ kernel_size = 3 * L if lookahead else L
368
+ self.in_conv = nn.Sequential(
369
+ nn.Conv1d(in_channels=1,
370
+ out_channels=enc_dim, kernel_size=kernel_size, stride=L,
371
+ padding=0, bias=False),
372
+ nn.ReLU())
373
+
374
+ # Label embedding layer
375
+ self.label_embedding = nn.Sequential(
376
+ nn.Linear(label_len, 512),
377
+ nn.LayerNorm(512),
378
+ nn.ReLU(),
379
+ nn.Linear(512, enc_dim),
380
+ nn.LayerNorm(enc_dim),
381
+ nn.ReLU())
382
+
383
+ # Mask generator
384
+ self.mask_gen = MaskNet(
385
+ enc_dim=enc_dim, num_enc_layers=num_enc_layers,
386
+ dec_dim=dec_dim, dec_buf_len=dec_buf_len,
387
+ dec_chunk_size=dec_chunk_size, num_dec_layers=num_dec_layers,
388
+ use_pos_enc=use_pos_enc, skip_connection=skip_connection, proj=proj)
389
+
390
+ # Output conv layer
391
+ self.out_conv = nn.Sequential(
392
+ nn.ConvTranspose1d(
393
+ in_channels=enc_dim, out_channels=1,
394
+ kernel_size=(out_buf_len + 1) * L,
395
+ stride=L,
396
+ padding=out_buf_len * L, bias=False),
397
+ nn.Tanh())
398
+
399
+ def init_buffers(self, batch_size, device):
400
+ enc_buf = self.mask_gen.encoder.init_ctx_buf(batch_size, device)
401
+ dec_buf = self.mask_gen.decoder.init_ctx_buf(batch_size, device)
402
+ out_buf = torch.zeros(batch_size, self.enc_dim, self.out_buf_len,
403
+ device=device)
404
+ return enc_buf, dec_buf, out_buf
405
+
406
+ def forward(self, x, label, init_enc_buf=None, init_dec_buf=None,
407
+ init_out_buf=None, pad=True):
408
+ """
409
+ Extracts the audio corresponding to the `label` in the given
410
+ `mixture`. Generates `chunk_size` samples per iteration.
411
+
412
+ Args:
413
+ mixed: [B, n_mics, T]
414
+ input audio mixture
415
+ label: [B, num_labels]
416
+ one hot label
417
+ Returns:
418
+ out: [B, n_spk, T]
419
+ extracted audio with sounds corresponding to the `label`
420
+ """
421
+ mod = 0
422
+ if pad:
423
+ pad_size = (self.L, self.L) if self.lookahead else (0, 0)
424
+ x, mod = mod_pad(x, chunk_size=self.L, pad=pad_size)
425
+
426
+ if init_enc_buf is None or init_dec_buf is None or init_out_buf is None:
427
+ assert init_enc_buf is None and \
428
+ init_dec_buf is None and \
429
+ init_out_buf is None, \
430
+ "Both buffers have to initialized, or " \
431
+ "both of them have to be None."
432
+ enc_buf, dec_buf, out_buf = self.init_buffers(
433
+ x.shape[0], x.device)
434
+ else:
435
+ enc_buf, dec_buf, out_buf = \
436
+ init_enc_buf, init_dec_buf, init_out_buf
437
+
438
+ # Generate latent space representation of the input
439
+ x = self.in_conv(x)
440
+
441
+ # Generate label embedding
442
+ l = self.label_embedding(label) # [B, label_len] --> [B, channels]
443
+
444
+ # Generate mask corresponding to the label
445
+ m, enc_buf, dec_buf = self.mask_gen(x, l, enc_buf, dec_buf)
446
+
447
+ # Apply mask and decode
448
+ x = x * m
449
+ x = torch.cat((out_buf, x), dim=-1)
450
+ out_buf = x[..., -self.out_buf_len:]
451
+ x = self.out_conv(x)
452
+
453
+ # Remove mod padding, if present.
454
+ if mod != 0:
455
+ x = x[:, :, :-mod]
456
+
457
+ if init_enc_buf is None:
458
+ return x
459
+ else:
460
+ return x, enc_buf, dec_buf, out_buf
461
+
462
+ # Define optimizer, loss and metrics
463
+
464
+ def optimizer(model, data_parallel=False, **kwargs):
465
+ return optim.Adam(model.parameters(), **kwargs)
466
+
467
+ def loss(pred, tgt):
468
+ return -0.9 * snr(pred, tgt).mean() - 0.1 * si_snr(pred, tgt).mean()
469
+
470
+ def metrics(mixed, output, gt):
471
+ """ Function to compute metrics """
472
+ metrics = {}
473
+
474
+ def metric_i(metric, src, pred, tgt):
475
+ _vals = []
476
+ for s, t, p in zip(src, tgt, pred):
477
+ _vals.append((metric(p, t) - metric(s, t)).cpu().item())
478
+ return _vals
479
+
480
+ for m_fn in [snr, si_snr]:
481
+ metrics[m_fn.__name__] = metric_i(m_fn,
482
+ mixed[:, :gt.shape[1], :],
483
+ output,
484
+ gt)
485
+
486
+ return metrics
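A hedged sketch of running the model above on a dummy mixture with a one-hot query, following the `[B, n_mics, T]` and `[B, num_labels]` shapes documented in Net.forward; weights are random here, while app.py shows how default_ckpt.pt is loaded on top of this:

# Sketch only: forward pass of Net with random weights and a dummy 1 s mixture.
import torch
from src.helpers import utils
from src.training.dcc_tf import Net

params = utils.Params("default_config.json")
model = Net(**params.model_params)
model.eval()

mixture = torch.randn(1, 1, 44100)                        # [B, n_mics, T]
query = torch.zeros(1, params.model_params["label_len"])  # [B, num_labels]
query[0, 0] = 1.0                                         # select one target class
with torch.no_grad():
    out = model(mixture, query)                           # same length as the input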
src/training/eval.py ADDED
@@ -0,0 +1,214 @@
1
+ """
2
+ Test script to evaluate the model.
3
+ """
4
+
5
+ import argparse
6
+ import importlib
7
+ import multiprocessing
8
+ import os, glob
9
+ import logging
10
+
11
+ import numpy as np
12
+ import torch
13
+ import pandas as pd
14
+ import torch.nn as nn
15
+ from torch.utils.tensorboard import SummaryWriter
16
+ from torch.profiler import profile, record_function, ProfilerActivity
17
+ from tqdm import tqdm # pylint: disable=unused-import
18
+ from torchmetrics.functional import(
19
+ scale_invariant_signal_noise_ratio as si_snr,
20
+ signal_noise_ratio as snr,
21
+ signal_distortion_ratio as sdr,
22
+ scale_invariant_signal_distortion_ratio as si_sdr)
23
+
24
+ from src.helpers import utils
25
+ from src.training.synthetic_dataset import FSDSoundScapesDataset, tensorboard_add_metrics
26
+ from src.training.synthetic_dataset import tensorboard_add_sample
27
+
28
+ def test_epoch(model: nn.Module, device: torch.device,
29
+ test_loader: torch.utils.data.dataloader.DataLoader,
30
+ n_items: int, loss_fn, metrics_fn,
31
+ profiling: bool = False, epoch: int = 0,
32
+ writer: SummaryWriter = None, data_params = None) -> float:
33
+ """
34
+ Evaluate the network.
35
+ """
36
+ model.eval()
37
+ metrics = {}
38
+
39
+ with torch.no_grad():
40
+ for batch_idx, (mixed, label, gt) in \
41
+ enumerate(tqdm(test_loader, desc='Test', ncols=100)):
42
+ mixed = mixed.to(device)
43
+ label = label.to(device)
44
+ gt = gt.to(device)
45
+
46
+ # Run through the model
47
+ with profile(activities=[ProfilerActivity.CPU],
48
+ record_shapes=True) as prof:
49
+ with record_function("model_inference"):
50
+ output = model(mixed, label)
51
+ if profiling:
52
+ logging.info(
53
+ prof.key_averages().table(sort_by="self_cpu_time_total",
54
+ row_limit=20))
55
+
56
+ # Compute loss
57
+ loss = loss_fn(output, gt)
58
+
59
+ # Compute metrics
60
+ metrics_batch = metrics_fn(mixed, output, gt)
61
+ metrics_batch['loss'] = [loss.item()]
62
+ metrics_batch['runtime'] = [prof.profiler.self_cpu_time_total/1000]
63
+ for k in metrics_batch.keys():
64
+ if not k in metrics:
65
+ metrics[k] = metrics_batch[k]
66
+ else:
67
+ metrics[k] += metrics_batch[k]
68
+
69
+ if writer is not None:
70
+ if batch_idx == 0:
71
+ tensorboard_add_sample(
72
+ writer, tag='Test',
73
+ sample=(mixed[:8], label[:8], gt[:8], output[:8]),
74
+ step=epoch, params=data_params)
75
+ tensorboard_add_metrics(
76
+ writer, tag='Test', metrics=metrics_batch, label=label,
77
+ step=epoch)
78
+
79
+ if n_items is not None and batch_idx == (n_items - 1):
80
+ break
81
+
82
+ avg_metrics = {k: np.mean(metrics[k]) for k in metrics.keys()}
83
+ avg_metrics_str = "Test:"
84
+ for m in avg_metrics.keys():
85
+ avg_metrics_str += ' %s=%.04f' % (m, avg_metrics[m])
86
+ logging.info(avg_metrics_str)
87
+
88
+ return avg_metrics
89
+
90
+ def evaluate(network, args: argparse.Namespace):
91
+ """
92
+ Evaluate the model on a given dataset.
93
+ """
94
+
95
+ # Load dataset
96
+ data_test = FSDSoundScapesDataset(**args.test_data)
97
+ logging.info("Loaded test dataset at %s containing %d elements" %
98
+ (args.test_data['input_dir'], len(data_test)))
99
+
100
+ # Set up the device and workers.
101
+ use_cuda = args.use_cuda and torch.cuda.is_available()
102
+ if use_cuda:
103
+ gpu_ids = args.gpu_ids if args.gpu_ids is not None\
104
+ else range(torch.cuda.device_count())
105
+ device_ids = [_ for _ in gpu_ids]
106
+ data_parallel = len(device_ids) > 1
107
+ device = 'cuda:%d' % device_ids[0]
108
+ torch.cuda.set_device(device_ids[0])
109
+ logging.info("Using CUDA devices: %s" % str(device_ids))
110
+ else:
111
+ data_parallel = False
112
+ device = torch.device('cpu')
113
+ logging.info("Using device: CPU")
114
+
115
+ # Set multiprocessing params
116
+ num_workers = min(multiprocessing.cpu_count(), args.n_workers)
117
+ kwargs = {
118
+ 'num_workers': num_workers,
119
+ 'pin_memory': True
120
+ } if use_cuda else {}
121
+
122
+ # Set up data loader
123
+ test_loader = torch.utils.data.DataLoader(data_test,
124
+ batch_size=args.eval_batch_size,
125
+ **kwargs)
126
+
127
+ # Set up model
128
+ model = network.Net(**args.model_params)
129
+ if use_cuda and data_parallel:
130
+ model = nn.DataParallel(model, device_ids=device_ids)
131
+ logging.info("Using data parallel model")
132
+ model.to(device)
133
+
134
+ # Load weights
135
+ if args.pretrain_path == "best":
136
+ ckpts = glob.glob(os.path.join(args.exp_dir, '*.pt'))
137
+ ckpts.sort(
138
+ key=lambda _: int(os.path.splitext(os.path.basename(_))[0]))
139
+ val_metrics = torch.load(ckpts[-1])['val_metrics'][args.base_metric]
140
+ best_epoch = max(range(len(val_metrics)), key=val_metrics.__getitem__)
141
+ args.pretrain_path = os.path.join(args.exp_dir, '%d.pt' % best_epoch)
142
+ logging.info(
143
+ "Found 'best' validation %s=%.02f at %s" %
144
+ (args.base_metric, val_metrics[best_epoch], args.pretrain_path))
145
+ if args.pretrain_path != "":
146
+ utils.load_checkpoint(
147
+ args.pretrain_path, model, data_parallel=data_parallel)
148
+ logging.info("Loaded pretrain weights from %s" % args.pretrain_path)
149
+
150
+ # Evaluate
151
+ try:
152
+ return test_epoch(
153
+ model, device, test_loader, args.n_items, network.loss,
154
+ network.metrics, args.profiling)
155
+ except KeyboardInterrupt:
156
+ print("Interrupted")
157
+ except Exception as _: # pylint: disable=broad-except
158
+ import traceback # pylint: disable=import-outside-toplevel
159
+ traceback.print_exc()
160
+
161
+
162
+ if __name__ == '__main__':
163
+ parser = argparse.ArgumentParser()
164
+ # Data Params
165
+ parser.add_argument('experiments', nargs='+', type=str,
166
+ default=None,
167
+ help="List of experiments to evaluate. "
168
+ "Provide only one experiment when providing "
169
+ "pretrained path. If pretrianed path is not "
170
+ "provided, epoch with best validation metric "
171
+ "is used for evaluation.")
172
+ parser.add_argument('--results', type=str, default="",
173
+ help="Path to the CSV file to store results.")
174
+
175
+ # System params
176
+ parser.add_argument('--n_items', type=int, default=None,
177
+ help="Number of items to test.")
178
+ parser.add_argument('--pretrain_path', type=str, default="best",
179
+ help="Path to pretrained weights")
180
+ parser.add_argument('--profiling', dest='profiling', action='store_true',
181
+ help="Enable or disable profiling.")
182
+ parser.add_argument('--use_cuda', dest='use_cuda', action='store_true',
183
+ help="Whether to use cuda")
184
+ parser.add_argument('--gpu_ids', nargs='+', type=int, default=None,
185
+ help="List of GPU ids used for training. "
186
+ "Eg., --gpu_ids 2 4. All GPUs are used by default.")
187
+ args = parser.parse_args()
188
+
189
+ results = []
190
+
191
+ for exp_dir in args.experiments:
192
+ eval_args = argparse.Namespace(**vars(args))
193
+ eval_args.exp_dir = exp_dir
194
+
195
+ utils.set_logger(os.path.join(exp_dir, 'eval.log'))
196
+ logging.info("Evaluating %s ..." % exp_dir)
197
+
198
+ # Load model and training params
199
+ params = utils.Params(os.path.join(exp_dir, 'config.json'))
200
+ for k, v in params.__dict__.items():
201
+ vars(eval_args)[k] = v
202
+
203
+ network = importlib.import_module(eval_args.model)
204
+ logging.info("Imported the model from '%s'." % eval_args.model)
205
+
206
+ curr_res = evaluate(network, eval_args)
207
+ curr_res['experiment'] = os.path.basename(exp_dir)
208
+ results.append(curr_res)
209
+
210
+ del eval_args
211
+
212
+ if args.results != "":
213
+ print("Writing results to %s" % args.results)
214
+ pd.DataFrame(results).to_csv(args.results, index=False)
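A hedged sketch of the programmatic equivalent of the CLI above: load one experiment's config, import its model module, and call evaluate(); the `experiments/fsd_mask_label_mult` directory mirrors train.py's default path and is assumed to contain config.json and checkpoints.

# Sketch only: evaluate one experiment directory without the CLI.
import argparse, importlib, os
from src.helpers import utils
from src.training.eval import evaluate

exp_dir = "experiments/fsd_mask_label_mult"          # hypothetical experiment dir
args = argparse.Namespace(exp_dir=exp_dir, pretrain_path="best", n_items=None,
                          profiling=False, use_cuda=False, gpu_ids=None)
params = utils.Params(os.path.join(exp_dir, "config.json"))
for k, v in params.__dict__.items():
    setattr(args, k, v)

network = importlib.import_module(args.model)        # e.g. src.training.dcc_tf
metrics = evaluate(network, args)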
src/training/synthetic_dataset.py ADDED
@@ -0,0 +1,168 @@
1
+ """
2
+ Torch dataset object for synthetically rendered spatial data.
3
+ """
4
+
5
+ import os
6
+ import json
7
+ import random
8
+ from pathlib import Path
9
+ import logging
10
+
11
+ import numpy as np
12
+ import pandas as pd
13
+ import matplotlib.pyplot as plt
14
+ import scaper
15
+ import torch
16
+ import torchaudio
17
+ import torchaudio.transforms as AT
18
+ from random import randrange
19
+
20
+ class FSDSoundScapesDataset(torch.utils.data.Dataset): # type: ignore
21
+ """
22
+ Base class for FSD Sound Scapes dataset
23
+ """
24
+
25
+ _labels = [
26
+ "Acoustic_guitar", "Applause", "Bark", "Bass_drum",
27
+ "Burping_or_eructation", "Bus", "Cello", "Chime", "Clarinet",
28
+ "Computer_keyboard", "Cough", "Cowbell", "Double_bass",
29
+ "Drawer_open_or_close", "Electric_piano", "Fart", "Finger_snapping",
30
+ "Fireworks", "Flute", "Glockenspiel", "Gong", "Gunshot_or_gunfire",
31
+ "Harmonica", "Hi-hat", "Keys_jangling", "Knock", "Laughter", "Meow",
32
+ "Microwave_oven", "Oboe", "Saxophone", "Scissors", "Shatter",
33
+ "Snare_drum", "Squeak", "Tambourine", "Tearing", "Telephone",
34
+ "Trumpet", "Violin_or_fiddle", "Writing"]
35
+
36
+ def __init__(self, input_dir, dset='', sr=None,
37
+ resample_rate=None, max_num_targets=1):
38
+ assert dset in ['train', 'val', 'test'], \
39
+ "`dset` must be one of ['train', 'val', 'test']"
40
+ self.dset = dset
41
+ self.max_num_targets = max_num_targets
42
+ self.fg_dir = os.path.join(input_dir, 'FSDKaggle2018/%s' % dset)
43
+ if dset in ['train', 'val']:
44
+ self.bg_dir = os.path.join(
45
+ input_dir,
46
+ 'TAU-acoustic-sounds/'
47
+ 'TAU-urban-acoustic-scenes-2019-development')
48
+ else:
49
+ self.bg_dir = os.path.join(
50
+ input_dir,
51
+ 'TAU-acoustic-sounds/'
52
+ 'TAU-urban-acoustic-scenes-2019-evaluation')
53
+ logging.info("Loading %s dataset: fg_dir=%s bg_dir=%s" %
54
+ (dset, self.fg_dir, self.bg_dir))
55
+
56
+ self.samples = sorted(list(
57
+ Path(os.path.join(input_dir, 'jams', dset)).glob('[0-9]*')))
58
+
59
+ jamsfile = os.path.join(self.samples[0], 'mixture.jams')
60
+ _, jams, _, _ = scaper.generate_from_jams(
61
+ jamsfile, fg_path=self.fg_dir, bg_path=self.bg_dir)
62
+ _sr = jams['annotations'][0]['sandbox']['scaper']['sr']
63
+ assert _sr == sr, "Sampling rate provided does not match the data"
64
+
65
+ if resample_rate is not None:
66
+ self.resampler = AT.Resample(sr, resample_rate)
67
+ self.sr = resample_rate
68
+ else:
69
+ self.resampler = lambda a: a
70
+ self.sr = sr
71
+
72
+ def _get_label_vector(self, labels):
73
+ """
74
+ Generates a multi-hot vector corresponding to `labels`.
75
+ """
76
+ vector = torch.zeros(len(FSDSoundScapesDataset._labels))
77
+
78
+ for label in labels:
79
+ idx = FSDSoundScapesDataset._labels.index(label)
80
+ assert vector[idx] == 0, "Repeated labels"
81
+ vector[idx] = 1
82
+
83
+ return vector
84
+
85
+ def __len__(self):
86
+ return len(self.samples)
87
+
88
+ def __getitem__(self, idx):
89
+ sample_path = self.samples[idx]
90
+ jamsfile = os.path.join(sample_path, 'mixture.jams')
91
+
92
+ mixture, jams, ann_list, event_audio_list = scaper.generate_from_jams(
93
+ jamsfile, fg_path=self.fg_dir, bg_path=self.bg_dir)
94
+ isolated_events = {}
95
+ for e, a in zip(ann_list, event_audio_list[1:]):
96
+ # 0th event is background
97
+ isolated_events[e[2]] = a
98
+ gt_events = list(pd.read_csv(
99
+ os.path.join(sample_path, 'gt_events.csv'), sep='\t')['label'])
100
+
101
+ mixture = torch.from_numpy(mixture).permute(1, 0)
102
+ mixture = self.resampler(mixture.to(torch.float))
103
+
104
+ if self.dset == 'train':
105
+ labels = random.sample(gt_events, randrange(1,self.max_num_targets+1))
106
+ elif self.dset == 'val':
107
+ labels = gt_events[:idx%self.max_num_targets+1]
108
+ elif self.dset == 'test':
109
+ labels = gt_events[:self.max_num_targets]
110
+ label_vector = self._get_label_vector(labels)
111
+
112
+ gt = torch.zeros_like(
113
+ torch.from_numpy(event_audio_list[1]).permute(1, 0))
114
+ for l in labels:
115
+ gt = gt + torch.from_numpy(isolated_events[l]).permute(1, 0)
116
+ gt = self.resampler(gt.to(torch.float))
117
+
118
+ return mixture, label_vector, gt #, jams
119
+
120
+ def tensorboard_add_sample(writer, tag, sample, step, params):
121
+ """
122
+ Adds a sample of FSDSoundScapesDataset to tensorboard.
123
+ """
124
+ if params['resample_rate'] is not None:
125
+ sr = params['resample_rate']
126
+ else:
127
+ sr = params['sr']
128
+ resample_rate = 16000 if sr > 16000 else sr
129
+
130
+ m, l, gt, o = sample
131
+ m, gt, o = (
132
+ torchaudio.functional.resample(_, sr, resample_rate).cpu()
133
+ for _ in (m, gt, o))
134
+
135
+ def _add_audio(a, audio_tag, axis, plt_title):
136
+ for i, ch in enumerate(a):
137
+ axis.plot(ch, label='mic %d' % i)
138
+ writer.add_audio(
139
+ '%s/mic %d' % (audio_tag, i), ch.unsqueeze(0), step, resample_rate)
140
+ axis.set_title(plt_title)
141
+ axis.legend()
142
+
143
+ for b in range(m.shape[0]):
144
+ label = []
145
+ for i in range(len(l[b, :])):
146
+ if l[b, i] == 1:
147
+ label.append(FSDSoundScapesDataset._labels[i])
148
+
149
+ # Add waveforms
150
+ rows = 3 # input, output, gt
151
+ fig = plt.figure(figsize=(10, 2 * rows))
152
+ axes = fig.subplots(rows, 1, sharex=True)
153
+ _add_audio(m[b], '%s/sample_%d/0_input' % (tag, b), axes[0], "Mixed")
154
+ _add_audio(o[b], '%s/sample_%d/1_output' % (tag, b), axes[1], "Output (%s)" % label)
155
+ _add_audio(gt[b], '%s/sample_%d/2_gt' % (tag, b), axes[2], "GT (%s)" % label)
156
+ writer.add_figure('%s/sample_%d/waveform' % (tag, b), fig, step)
157
+
158
+ def tensorboard_add_metrics(writer, tag, metrics, label, step):
159
+ """
160
+ Add metrics to tensorboard.
161
+ """
162
+ vals = np.asarray(metrics['scale_invariant_signal_noise_ratio'])
163
+
164
+ writer.add_histogram('%s/%s' % (tag, 'SI-SNRi'), vals, step)
165
+
166
+ label_names = [FSDSoundScapesDataset._labels[torch.argmax(_)] for _ in label]
167
+ for l, v in zip(label_names, vals):
168
+ writer.add_histogram('%s/%s' % (tag, l), v, step)
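A hedged sketch of instantiating the dataset with the test-split settings from default_config.json; this assumes the FSDSoundScapes/FSDKaggle2018/TAU data layout referenced above has already been prepared under data/:

# Sketch only: build the test split and fetch one (mixture, query, target) triple.
from src.training.synthetic_dataset import FSDSoundScapesDataset

dset = FSDSoundScapesDataset(input_dir="data/FSDSoundScapes", dset="test",
                             sr=44100, resample_rate=None, max_num_targets=3)
mixture, label_vector, gt = dset[0]   # [channels, T] audio, 41-dim multi-hot, [channels, T] target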
src/training/train.py ADDED
@@ -0,0 +1,311 @@
1
+ """
2
+ The main training script for training on synthetic data
3
+ """
4
+
5
+ import argparse
6
+ import multiprocessing
7
+ import os
8
+ import logging
9
+ from pathlib import Path
10
+ import random
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ import torch.optim as optim
17
+ from torch.utils.tensorboard import SummaryWriter
18
+ from tqdm import tqdm # pylint: disable=unused-import
19
+ from torchmetrics.functional import(
20
+ scale_invariant_signal_noise_ratio as si_snr,
21
+ signal_noise_ratio as snr,
22
+ signal_distortion_ratio as sdr,
23
+ scale_invariant_signal_distortion_ratio as si_sdr)
24
+
25
+ from src.helpers import utils
26
+ from src.training.eval import test_epoch
27
+ from src.training.synthetic_dataset import FSDSoundScapesDataset as Dataset
28
+ from src.training.synthetic_dataset import tensorboard_add_sample
29
+
30
+ def train_epoch(model: nn.Module, device: torch.device,
31
+ optimizer: optim.Optimizer,
32
+ train_loader: torch.utils.data.dataloader.DataLoader,
33
+ n_items: int, epoch: int = 0,
34
+ writer: SummaryWriter = None, data_params = None) -> float:
35
+
36
+ """
37
+ Train a single epoch.
38
+ """
39
+ # Set the model to training.
40
+ model.train()
41
+
42
+ # Training loop
43
+ losses = []
44
+ metrics = {}
45
+
46
+ with tqdm(total=len(train_loader), desc='Train', ncols=100) as t:
47
+ for batch_idx, (mixed, label, gt) in enumerate(train_loader):
48
+ mixed = mixed.to(device)
49
+ label = label.to(device)
50
+ gt = gt.to(device)
51
+
52
+ # Reset grad
53
+ optimizer.zero_grad()
54
+
55
+ # Run through the model
56
+ output = model(mixed, label)
57
+
58
+ # Compute loss
59
+ loss = network.loss(output, gt)
60
+
61
+ losses.append(loss.item())
62
+
63
+ # Backpropagation
64
+ loss.backward()
65
+
66
+ # Gradient clipping
67
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
68
+
69
+ # Update the weights
70
+ optimizer.step()
71
+
72
+ metrics_batch = network.metrics(mixed.detach(), output.detach(),
73
+ gt.detach())
74
+ for k in metrics_batch.keys():
75
+ if not k in metrics:
76
+ metrics[k] = metrics_batch[k]
77
+ else:
78
+ metrics[k] += metrics_batch[k]
79
+
80
+ if writer is not None and batch_idx == 0:
81
+ tensorboard_add_sample(
82
+ writer, tag='Train',
83
+ sample=(mixed.detach()[:8], label.detach()[:8],
84
+ gt.detach()[:8], output.detach()[:8]),
85
+ step=epoch, params=data_params)
86
+
87
+ # Show current loss in the progress meter
88
+ t.set_postfix(loss='%.05f'%loss.item())
89
+ t.update()
90
+
91
+ if n_items is not None and batch_idx == n_items:
92
+ break
93
+
94
+ avg_metrics = {k: np.mean(metrics[k]) for k in metrics.keys()}
95
+ avg_metrics['loss'] = np.mean(losses)
96
+ avg_metrics_str = "Train:"
97
+ for m in avg_metrics.keys():
98
+ avg_metrics_str += ' %s=%.04f' % (m, avg_metrics[m])
99
+ logging.info(avg_metrics_str)
100
+
101
+ return avg_metrics
102
+
103
+
104
+ def train(args: argparse.Namespace):
105
+ """
106
+ Train the network.
107
+ """
108
+
109
+ # Load dataset
110
+ data_train = Dataset(**args.train_data)
111
+ logging.info("Loaded train dataset at %s containing %d elements" %
112
+ (args.train_data['input_dir'], len(data_train)))
113
+ data_val = Dataset(**args.val_data)
114
+ logging.info("Loaded test dataset at %s containing %d elements" %
115
+ (args.val_data['input_dir'], len(data_val)))
116
+
117
+ # Set up the device and workers.
118
+ use_cuda = args.use_cuda and torch.cuda.is_available()
119
+ if use_cuda:
120
+ gpu_ids = args.gpu_ids if args.gpu_ids is not None\
121
+ else range(torch.cuda.device_count())
122
+ device_ids = [_ for _ in gpu_ids]
123
+ data_parallel = len(device_ids) > 1
124
+ device = 'cuda:%d' % device_ids[0]
125
+ torch.cuda.set_device(device_ids[0])
126
+ logging.info("Using CUDA devices: %s" % str(device_ids))
127
+ else:
128
+ data_parallel = False
129
+ device = torch.device('cpu')
130
+ logging.info("Using device: CPU")
131
+
132
+ # Set multiprocessing params
133
+ num_workers = min(multiprocessing.cpu_count(), args.n_workers)
134
+ kwargs = {
135
+ 'num_workers': num_workers,
136
+ 'pin_memory': True
137
+ } if use_cuda else {}
138
+
139
+ # Set up data loaders
140
+ #print(args.batch_size, args.eval_batch_size)
141
+ train_loader = torch.utils.data.DataLoader(data_train,
142
+ batch_size=args.batch_size,
143
+ shuffle=True, **kwargs)
144
+ val_loader = torch.utils.data.DataLoader(data_val,
145
+ batch_size=args.eval_batch_size,
146
+ **kwargs)
147
+
148
+ # Set up model
149
+ model = network.Net(**args.model_params)
150
+
151
+ # Add graph to tensorboard with example train samples
152
+ # _mixed, _label, _ = next(iter(val_loader))
153
+ # args.writer.add_graph(model, (_mixed, _label))
154
+
155
+ if use_cuda and data_parallel:
156
+ model = nn.DataParallel(model, device_ids=device_ids)
157
+ logging.info("Using data parallel model")
158
+ model.to(device)
159
+
160
+ # Set up the optimizer
161
+ logging.info("Initializing optimizer with %s" % str(args.optim))
162
+ optimizer = network.optimizer(model, **args.optim, data_parallel=data_parallel)
163
+ logging.info('Learning rates initialized to:' + utils.format_lr_info(optimizer))
164
+
165
+ lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
166
+ optimizer, **args.lr_sched)
167
+ logging.info("Initialized LR scheduler with params: fix_lr_epochs=%d %s"
168
+ % (args.fix_lr_epochs, str(args.lr_sched)))
169
+
170
+ base_metric = args.base_metric
171
+ train_metrics = {}
172
+ val_metrics = {}
173
+
174
+ # Load the model if `args.start_epoch` is greater than 0. This will load the
175
+ # model from epoch = `args.start_epoch - 1`
176
+ assert args.start_epoch >= 0, "start_epoch must be non-negative."
177
+ if args.start_epoch > 0:
178
+ checkpoint_path = os.path.join(args.exp_dir,
179
+ '%d.pt' % (args.start_epoch - 1))
180
+ _, train_metrics, val_metrics = utils.load_checkpoint(
181
+ checkpoint_path, model, optim=optimizer, lr_sched=lr_scheduler,
182
+ data_parallel=data_parallel)
183
+ logging.info("Loaded checkpoint from %s" % checkpoint_path)
184
+ logging.info("Learning rates restored to:" + utils.format_lr_info(optimizer))
185
+
186
+ # Training loop
187
+ try:
188
+ torch.autograd.set_detect_anomaly(args.detect_anomaly)
189
+ for epoch in range(args.start_epoch, args.epochs + 1):
190
+ logging.info("Epoch %d:" % epoch)
191
+ checkpoint_file = os.path.join(args.exp_dir, '%d.pt' % epoch)
192
+ assert not os.path.exists(checkpoint_file), \
193
+ "Checkpoint file %s already exists" % checkpoint_file
194
+ #print("---- begin trianivg")
195
+ curr_train_metrics = train_epoch(model, device, optimizer,
196
+ train_loader, args.n_train_items,
197
+ epoch=epoch, writer=args.writer,
198
+ data_params=args.train_data)
199
+ #raise KeyboardInterrupt
200
+ curr_test_metrics = test_epoch(model, device, val_loader,
201
+ args.n_test_items, network.loss,
202
+ network.metrics, epoch=epoch,
203
+ writer=args.writer,
204
+ data_params=args.val_data)
205
+ # LR scheduler
206
+ if epoch >= args.fix_lr_epochs:
207
+ lr_scheduler.step(curr_test_metrics[base_metric])
208
+ logging.info(
209
+ "LR after scheduling step: %s" %
210
+ [_['lr'] for _ in optimizer.param_groups])
211
+
212
+ # Write metrics to tensorboard
213
+ args.writer.add_scalars('Train', curr_train_metrics, epoch)
214
+ args.writer.add_scalars('Val', curr_test_metrics, epoch)
215
+ args.writer.flush()
216
+
217
+ for k in curr_train_metrics.keys():
218
+ if not k in train_metrics:
219
+ train_metrics[k] = [curr_train_metrics[k]]
220
+ else:
221
+ train_metrics[k].append(curr_train_metrics[k])
222
+
223
+ for k in curr_test_metrics.keys():
224
+ if not k in val_metrics:
225
+ val_metrics[k] = [curr_test_metrics[k]]
226
+ else:
227
+ val_metrics[k].append(curr_test_metrics[k])
228
+
229
+ if max(val_metrics[base_metric]) == val_metrics[base_metric][-1]:
230
+ logging.info("Found best validation %s!" % base_metric)
231
+
232
+ utils.save_checkpoint(
233
+ checkpoint_file, epoch, model, optimizer, lr_scheduler,
234
+ train_metrics, val_metrics, data_parallel)
235
+ logging.info("Saved checkpoint at %s" % checkpoint_file)
236
+
237
+ utils.save_graph(train_metrics, val_metrics, args.exp_dir)
238
+
239
+ return train_metrics, val_metrics
240
+
241
+
242
+ except KeyboardInterrupt:
243
+ print("Interrupted")
244
+ except Exception as _: # pylint: disable=broad-except
245
+ import traceback # pylint: disable=import-outside-toplevel
246
+ traceback.print_exc()
247
+
248
+
249
+ if __name__ == '__main__':
250
+ parser = argparse.ArgumentParser()
251
+ # Data Params
252
+ parser.add_argument('exp_dir', type=str,
253
+ default='./experiments/fsd_mask_label_mult',
254
+ help="Path to save checkpoints and logs.")
255
+
256
+ parser.add_argument('--n_train_items', type=int, default=None,
257
+ help="Number of items to train on in each epoch")
258
+ parser.add_argument('--n_test_items', type=int, default=None,
259
+ help="Number of items to test.")
260
+ parser.add_argument('--start_epoch', type=int, default=0,
261
+ help="Start epoch")
262
+ parser.add_argument('--pretrain_path', type=str,
263
+ help="Path to pretrained weights")
264
+ parser.add_argument('--use_cuda', dest='use_cuda', action='store_true',
265
+ help="Whether to use cuda")
266
+ parser.add_argument('--gpu_ids', nargs='+', type=int, default=None,
267
+ help="List of GPU ids used for training. "
268
+ "Eg., --gpu_ids 2 4. All GPUs are used by default.")
269
+ parser.add_argument('--detect_anomaly', dest='detect_anomaly',
270
+ action='store_true',
271
+ help="Whether to use cuda")
272
+ parser.add_argument('--wandb', dest='wandb', action='store_true',
273
+ help="Whether to sync tensorboard to wandb")
274
+
275
+ args = parser.parse_args()
276
+
277
+ # Set the random seed for reproducible experiments
278
+ torch.manual_seed(230)
279
+ random.seed(230)
280
+ np.random.seed(230)
281
+ if args.use_cuda:
282
+ torch.cuda.manual_seed(230)
283
+
284
+ # Set up checkpoints
285
+ if not os.path.exists(args.exp_dir):
286
+ os.makedirs(args.exp_dir)
287
+
288
+ utils.set_logger(os.path.join(args.exp_dir, 'train.log'))
289
+
290
+ # Load model and training params
291
+ params = utils.Params(os.path.join(args.exp_dir, 'config.json'))
292
+ for k, v in params.__dict__.items():
293
+ vars(args)[k] = v
294
+
295
+ # Initialize tensorboard writer
296
+ tensorboard_dir = os.path.join(args.exp_dir, 'tensorboard')
297
+ args.writer = SummaryWriter(tensorboard_dir, purge_step=args.start_epoch)
298
+ if args.wandb:
299
+ import wandb
300
+ wandb.init(
301
+ project='Semaudio', sync_tensorboard=True,
302
+ dir=tensorboard_dir, name=os.path.basename(args.exp_dir))
303
+
304
+ exec("import %s as network" % args.model)
305
+ logging.info("Imported the model from '%s'." % args.model)
306
+
307
+ train(args)
308
+
309
+ args.writer.close()
310
+ if args.wandb:
311
+ wandb.finish()