Bandhav Veluri commited on
Commit
2db9aa5
β€’
1 Parent(s): b7c0655

Refactored source

Browse files
app.py CHANGED
@@ -6,7 +6,6 @@ import torch
6
  import torchaudio
7
  import gradio as gr
8
 
9
- from src.helpers import utils
10
  from src.training.dcc_tf import Net as Waveformer
11
 
12
  TARGETS = [
@@ -34,7 +33,7 @@ if not os.path.exists('default_ckpt.pt'):
34
  # Instantiate model
35
  params = utils.Params('default_config.json')
36
  model = Waveformer(**params.model_params)
37
- utils.load_checkpoint('default_ckpt.pt', model)
38
  model.eval()
39
 
40
  def waveformer(audio, label_choices):
 
6
  import torchaudio
7
  import gradio as gr
8
 
 
9
  from src.training.dcc_tf import Net as Waveformer
10
 
11
  TARGETS = [
 
33
  # Instantiate model
34
  params = utils.Params('default_config.json')
35
  model = Waveformer(**params.model_params)
36
+ model.load_state_dict(torch.load('default_ckpt.pt', map_location=torch.device('cpu')))
37
  model.eval()
38
 
39
  def waveformer(audio, label_choices):
src/training/dcc_tf.py β†’ dcc_tf.py RENAMED
File without changes
src/__init__.py DELETED
File without changes
src/helpers/__init__.py DELETED
File without changes
src/helpers/utils.py DELETED
@@ -1,205 +0,0 @@
1
- """A collection of useful helper functions"""
2
-
3
- import os
4
- import logging
5
- import json
6
-
7
- import torch
8
- from torch.profiler import profile, record_function, ProfilerActivity
9
- import pandas as pd
10
- from torchmetrics.functional import(
11
- scale_invariant_signal_noise_ratio as si_snr,
12
- signal_noise_ratio as snr,
13
- signal_distortion_ratio as sdr,
14
- scale_invariant_signal_distortion_ratio as si_sdr)
15
- import matplotlib.pyplot as plt
16
-
17
- class Params():
18
- """Class that loads hyperparameters from a json file.
19
- Example:
20
- ```
21
- params = Params(json_path)
22
- print(params.learning_rate)
23
- params.learning_rate = 0.5 # change the value of learning_rate in params
24
- ```
25
- """
26
-
27
- def __init__(self, json_path):
28
- with open(json_path) as f:
29
- params = json.load(f)
30
- self.__dict__.update(params)
31
-
32
- def save(self, json_path):
33
- with open(json_path, 'w') as f:
34
- json.dump(self.__dict__, f, indent=4)
35
-
36
- def update(self, json_path):
37
- """Loads parameters from json file"""
38
- with open(json_path) as f:
39
- params = json.load(f)
40
- self.__dict__.update(params)
41
-
42
- @property
43
- def dict(self):
44
- """Gives dict-like access to Params instance by `params.dict['learning_rate']"""
45
- return self.__dict__
46
-
47
- def save_graph(train_metrics, test_metrics, save_dir):
48
- metrics = [snr, si_snr]
49
- results = {'train_loss': train_metrics['loss'],
50
- 'test_loss' : test_metrics['loss']}
51
-
52
- for m_fn in metrics:
53
- results["train_"+m_fn.__name__] = train_metrics[m_fn.__name__]
54
- results["test_"+m_fn.__name__] = test_metrics[m_fn.__name__]
55
-
56
- results_pd = pd.DataFrame(results)
57
-
58
- results_pd.to_csv(os.path.join(save_dir, 'results.csv'))
59
-
60
- fig, temp_ax = plt.subplots(2, 3, figsize=(15,10))
61
- axs=[]
62
- for i in temp_ax:
63
- for j in i:
64
- axs.append(j)
65
-
66
- x = range(len(train_metrics['loss']))
67
- axs[0].plot(x, train_metrics['loss'], label='train')
68
- axs[0].plot(x, test_metrics['loss'], label='test')
69
- axs[0].set(ylabel='Loss')
70
- axs[0].set(xlabel='Epoch')
71
- axs[0].set_title('loss',fontweight='bold')
72
- axs[0].legend()
73
-
74
- for i in range(len(metrics)):
75
- axs[i+1].plot(x, train_metrics[metrics[i].__name__], label='train')
76
- axs[i+1].plot(x, test_metrics[metrics[i].__name__], label='test')
77
- axs[i+1].set(xlabel='Epoch')
78
- axs[i+1].set_title(metrics[i].__name__,fontweight='bold')
79
- axs[i+1].legend()
80
-
81
- plt.tight_layout()
82
- plt.savefig(os.path.join(save_dir, 'results.png'))
83
- plt.close(fig)
84
-
85
- def set_logger(log_path):
86
- """Set the logger to log info in terminal and file `log_path`.
87
- In general, it is useful to have a logger so that every output to the terminal is saved
88
- in a permanent file. Here we save it to `model_dir/train.log`.
89
- Example:
90
- ```
91
- logging.info("Starting training...")
92
- ```
93
- Args:
94
- log_path: (string) where to log
95
- """
96
- logger = logging.getLogger()
97
- logger.setLevel(logging.INFO)
98
- logger.handlers.clear()
99
-
100
- # Logging to a file
101
- file_handler = logging.FileHandler(log_path)
102
- file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
103
- logger.addHandler(file_handler)
104
-
105
- # Logging to console
106
- stream_handler = logging.StreamHandler()
107
- stream_handler.setFormatter(logging.Formatter('%(message)s'))
108
- logger.addHandler(stream_handler)
109
-
110
- def load_checkpoint(checkpoint, model, optim=None, lr_sched=None, data_parallel=False):
111
- """Loads model parameters (state_dict) from file_path.
112
-
113
- Args:
114
- checkpoint: (string) filename which needs to be loaded
115
- model: (torch.nn.Module) model for which the parameters are loaded
116
- data_parallel: (bool) if the model is a data parallel model
117
- """
118
- if not os.path.exists(checkpoint):
119
- raise("File doesn't exist {}".format(checkpoint))
120
-
121
- state_dict = torch.load(checkpoint)
122
-
123
- if data_parallel:
124
- state_dict['model_state_dict'] = {
125
- 'module.' + k: state_dict['model_state_dict'][k]
126
- for k in state_dict['model_state_dict'].keys()}
127
- model.load_state_dict(state_dict['model_state_dict'])
128
-
129
- if optim is not None:
130
- optim.load_state_dict(state_dict['optim_state_dict'])
131
-
132
- if lr_sched is not None:
133
- lr_sched.load_state_dict(state_dict['lr_sched_state_dict'])
134
-
135
- return state_dict['epoch'], state_dict['train_metrics'], \
136
- state_dict['val_metrics']
137
-
138
- def save_checkpoint(checkpoint, epoch, model, optim=None, lr_sched=None,
139
- train_metrics=None, val_metrics=None, data_parallel=False):
140
- """Saves model parameters (state_dict) to file_path.
141
-
142
- Args:
143
- checkpoint: (string) filename which needs to be loaded
144
- model: (torch.nn.Module) model for which the parameters are loaded
145
- data_parallel: (bool) if the model is a data parallel model
146
- """
147
- if os.path.exists(checkpoint):
148
- raise("File already exists {}".format(checkpoint))
149
-
150
- model_state_dict = model.state_dict()
151
- if data_parallel:
152
- model_state_dict = {
153
- k.partition('module.')[2]:
154
- model_state_dict[k] for k in model_state_dict.keys()}
155
-
156
- optim_state_dict = None if not optim else optim.state_dict()
157
- lr_sched_state_dict = None if not lr_sched else lr_sched.state_dict()
158
-
159
- state_dict = {
160
- 'epoch': epoch,
161
- 'model_state_dict': model_state_dict,
162
- 'optim_state_dict': optim_state_dict,
163
- 'lr_sched_state_dict': lr_sched_state_dict,
164
- 'train_metrics': train_metrics,
165
- 'val_metrics': val_metrics
166
- }
167
-
168
- torch.save(state_dict, checkpoint)
169
-
170
- def model_size(model):
171
- """
172
- Returns size of the `model` in millions of parameters.
173
- """
174
- num_train_params = sum(
175
- p.numel() for p in model.parameters() if p.requires_grad)
176
- return num_train_params / 1e6
177
-
178
- def run_time(model, inputs, profiling=False):
179
- """
180
- Returns runtime of a model in ms.
181
- """
182
- # Warmup
183
- for _ in range(100):
184
- output = model(*inputs)
185
-
186
- with profile(activities=[ProfilerActivity.CPU],
187
- record_shapes=True) as prof:
188
- with record_function("model_inference"):
189
- output = model(*inputs)
190
-
191
- # Print profiling results
192
- if profiling:
193
- print(prof.key_averages().table(sort_by="self_cpu_time_total",
194
- row_limit=20))
195
-
196
- # Return runtime in ms
197
- return prof.profiler.self_cpu_time_total / 1000
198
-
199
- def format_lr_info(optimizer):
200
- lr_info = ""
201
- for i, pg in enumerate(optimizer.param_groups):
202
- lr_info += " {group %d: params=%.5fM lr=%.1E}" % (
203
- i, sum([p.numel() for p in pg['params']]) / (1024 ** 2), pg['lr'])
204
- return lr_info
205
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/training/__init__.py DELETED
File without changes
src/training/eval.py DELETED
@@ -1,214 +0,0 @@
1
- """
2
- Test script to evaluate the model.
3
- """
4
-
5
- import argparse
6
- import importlib
7
- import multiprocessing
8
- import os, glob
9
- import logging
10
-
11
- import numpy as np
12
- import torch
13
- import pandas as pd
14
- import torch.nn as nn
15
- from torch.utils.tensorboard import SummaryWriter
16
- from torch.profiler import profile, record_function, ProfilerActivity
17
- from tqdm import tqdm # pylint: disable=unused-import
18
- from torchmetrics.functional import(
19
- scale_invariant_signal_noise_ratio as si_snr,
20
- signal_noise_ratio as snr,
21
- signal_distortion_ratio as sdr,
22
- scale_invariant_signal_distortion_ratio as si_sdr)
23
-
24
- from src.helpers import utils
25
- from src.training.synthetic_dataset import FSDSoundScapesDataset, tensorboard_add_metrics
26
- from src.training.synthetic_dataset import tensorboard_add_sample
27
-
28
- def test_epoch(model: nn.Module, device: torch.device,
29
- test_loader: torch.utils.data.dataloader.DataLoader,
30
- n_items: int, loss_fn, metrics_fn,
31
- profiling: bool = False, epoch: int = 0,
32
- writer: SummaryWriter = None, data_params = None) -> float:
33
- """
34
- Evaluate the network.
35
- """
36
- model.eval()
37
- metrics = {}
38
-
39
- with torch.no_grad():
40
- for batch_idx, (mixed, label, gt) in \
41
- enumerate(tqdm(test_loader, desc='Test', ncols=100)):
42
- mixed = mixed.to(device)
43
- label = label.to(device)
44
- gt = gt.to(device)
45
-
46
- # Run through the model
47
- with profile(activities=[ProfilerActivity.CPU],
48
- record_shapes=True) as prof:
49
- with record_function("model_inference"):
50
- output = model(mixed, label)
51
- if profiling:
52
- logging.info(
53
- prof.key_averages().table(sort_by="self_cpu_time_total",
54
- row_limit=20))
55
-
56
- # Compute loss
57
- loss = loss_fn(output, gt)
58
-
59
- # Compute metrics
60
- metrics_batch = metrics_fn(mixed, output, gt)
61
- metrics_batch['loss'] = [loss.item()]
62
- metrics_batch['runtime'] = [prof.profiler.self_cpu_time_total/1000]
63
- for k in metrics_batch.keys():
64
- if not k in metrics:
65
- metrics[k] = metrics_batch[k]
66
- else:
67
- metrics[k] += metrics_batch[k]
68
-
69
- if writer is not None:
70
- if batch_idx == 0:
71
- tensorboard_add_sample(
72
- writer, tag='Test',
73
- sample=(mixed[:8], label[:8], gt[:8], output[:8]),
74
- step=epoch, params=data_params)
75
- tensorboard_add_metrics(
76
- writer, tag='Test', metrics=metrics_batch, label=label,
77
- step=epoch)
78
-
79
- if n_items is not None and batch_idx == (n_items - 1):
80
- break
81
-
82
- avg_metrics = {k: np.mean(metrics[k]) for k in metrics.keys()}
83
- avg_metrics_str = "Test:"
84
- for m in avg_metrics.keys():
85
- avg_metrics_str += ' %s=%.04f' % (m, avg_metrics[m])
86
- logging.info(avg_metrics_str)
87
-
88
- return avg_metrics
89
-
90
- def evaluate(network, args: argparse.Namespace):
91
- """
92
- Evaluate the model on a given dataset.
93
- """
94
-
95
- # Load dataset
96
- data_test = FSDSoundScapesDataset(**args.test_data)
97
- logging.info("Loaded test dataset at %s containing %d elements" %
98
- (args.test_data['input_dir'], len(data_test)))
99
-
100
- # Set up the device and workers.
101
- use_cuda = args.use_cuda and torch.cuda.is_available()
102
- if use_cuda:
103
- gpu_ids = args.gpu_ids if args.gpu_ids is not None\
104
- else range(torch.cuda.device_count())
105
- device_ids = [_ for _ in gpu_ids]
106
- data_parallel = len(device_ids) > 1
107
- device = 'cuda:%d' % device_ids[0]
108
- torch.cuda.set_device(device_ids[0])
109
- logging.info("Using CUDA devices: %s" % str(device_ids))
110
- else:
111
- data_parallel = False
112
- device = torch.device('cpu')
113
- logging.info("Using device: CPU")
114
-
115
- # Set multiprocessing params
116
- num_workers = min(multiprocessing.cpu_count(), args.n_workers)
117
- kwargs = {
118
- 'num_workers': num_workers,
119
- 'pin_memory': True
120
- } if use_cuda else {}
121
-
122
- # Set up data loader
123
- test_loader = torch.utils.data.DataLoader(data_test,
124
- batch_size=args.eval_batch_size,
125
- **kwargs)
126
-
127
- # Set up model
128
- model = network.Net(**args.model_params)
129
- if use_cuda and data_parallel:
130
- model = nn.DataParallel(model, device_ids=device_ids)
131
- logging.info("Using data parallel model")
132
- model.to(device)
133
-
134
- # Load weights
135
- if args.pretrain_path == "best":
136
- ckpts = glob.glob(os.path.join(args.exp_dir, '*.pt'))
137
- ckpts.sort(
138
- key=lambda _: int(os.path.splitext(os.path.basename(_))[0]))
139
- val_metrics = torch.load(ckpts[-1])['val_metrics'][args.base_metric]
140
- best_epoch = max(range(len(val_metrics)), key=val_metrics.__getitem__)
141
- args.pretrain_path = os.path.join(args.exp_dir, '%d.pt' % best_epoch)
142
- logging.info(
143
- "Found 'best' validation %s=%.02f at %s" %
144
- (args.base_metric, val_metrics[best_epoch], args.pretrain_path))
145
- if args.pretrain_path != "":
146
- utils.load_checkpoint(
147
- args.pretrain_path, model, data_parallel=data_parallel)
148
- logging.info("Loaded pretrain weights from %s" % args.pretrain_path)
149
-
150
- # Evaluate
151
- try:
152
- return test_epoch(
153
- model, device, test_loader, args.n_items, network.loss,
154
- network.metrics, args.profiling)
155
- except KeyboardInterrupt:
156
- print("Interrupted")
157
- except Exception as _: # pylint: disable=broad-except
158
- import traceback # pylint: disable=import-outside-toplevel
159
- traceback.print_exc()
160
-
161
-
162
- if __name__ == '__main__':
163
- parser = argparse.ArgumentParser()
164
- # Data Params
165
- parser.add_argument('experiments', nargs='+', type=str,
166
- default=None,
167
- help="List of experiments to evaluate. "
168
- "Provide only one experiment when providing "
169
- "pretrained path. If pretrianed path is not "
170
- "provided, epoch with best validation metric "
171
- "is used for evaluation.")
172
- parser.add_argument('--results', type=str, default="",
173
- help="Path to the CSV file to store results.")
174
-
175
- # System params
176
- parser.add_argument('--n_items', type=int, default=None,
177
- help="Number of items to test.")
178
- parser.add_argument('--pretrain_path', type=str, default="best",
179
- help="Path to pretrained weights")
180
- parser.add_argument('--profiling', dest='profiling', action='store_true',
181
- help="Enable or disable profiling.")
182
- parser.add_argument('--use_cuda', dest='use_cuda', action='store_true',
183
- help="Whether to use cuda")
184
- parser.add_argument('--gpu_ids', nargs='+', type=int, default=None,
185
- help="List of GPU ids used for training. "
186
- "Eg., --gpu_ids 2 4. All GPUs are used by default.")
187
- args = parser.parse_args()
188
-
189
- results = []
190
-
191
- for exp_dir in args.experiments:
192
- eval_args = argparse.Namespace(**vars(args))
193
- eval_args.exp_dir = exp_dir
194
-
195
- utils.set_logger(os.path.join(exp_dir, 'eval.log'))
196
- logging.info("Evaluating %s ..." % exp_dir)
197
-
198
- # Load model and training params
199
- params = utils.Params(os.path.join(exp_dir, 'config.json'))
200
- for k, v in params.__dict__.items():
201
- vars(eval_args)[k] = v
202
-
203
- network = importlib.import_module(eval_args.model)
204
- logging.info("Imported the model from '%s'." % eval_args.model)
205
-
206
- curr_res = evaluate(network, eval_args)
207
- curr_res['experiment'] = os.path.basename(exp_dir)
208
- results.append(curr_res)
209
-
210
- del eval_args
211
-
212
- if args.results != "":
213
- print("Writing results to %s" % args.results)
214
- pd.DataFrame(results).to_csv(args.results, index=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/training/synthetic_dataset.py DELETED
@@ -1,168 +0,0 @@
1
- """
2
- Torch dataset object for synthetically rendered spatial data.
3
- """
4
-
5
- import os
6
- import json
7
- import random
8
- from pathlib import Path
9
- import logging
10
-
11
- import numpy as np
12
- import pandas as pd
13
- import matplotlib.pyplot as plt
14
- import scaper
15
- import torch
16
- import torchaudio
17
- import torchaudio.transforms as AT
18
- from random import randrange
19
-
20
- class FSDSoundScapesDataset(torch.utils.data.Dataset): # type: ignore
21
- """
22
- Base class for FSD Sound Scapes dataset
23
- """
24
-
25
- _labels = [
26
- "Acoustic_guitar", "Applause", "Bark", "Bass_drum",
27
- "Burping_or_eructation", "Bus", "Cello", "Chime", "Clarinet",
28
- "Computer_keyboard", "Cough", "Cowbell", "Double_bass",
29
- "Drawer_open_or_close", "Electric_piano", "Fart", "Finger_snapping",
30
- "Fireworks", "Flute", "Glockenspiel", "Gong", "Gunshot_or_gunfire",
31
- "Harmonica", "Hi-hat", "Keys_jangling", "Knock", "Laughter", "Meow",
32
- "Microwave_oven", "Oboe", "Saxophone", "Scissors", "Shatter",
33
- "Snare_drum", "Squeak", "Tambourine", "Tearing", "Telephone",
34
- "Trumpet", "Violin_or_fiddle", "Writing"]
35
-
36
- def __init__(self, input_dir, dset='', sr=None,
37
- resample_rate=None, max_num_targets=1):
38
- assert dset in ['train', 'val', 'test'], \
39
- "`dset` must be one of ['train', 'val', 'test']"
40
- self.dset = dset
41
- self.max_num_targets = max_num_targets
42
- self.fg_dir = os.path.join(input_dir, 'FSDKaggle2018/%s' % dset)
43
- if dset in ['train', 'val']:
44
- self.bg_dir = os.path.join(
45
- input_dir,
46
- 'TAU-acoustic-sounds/'
47
- 'TAU-urban-acoustic-scenes-2019-development')
48
- else:
49
- self.bg_dir = os.path.join(
50
- input_dir,
51
- 'TAU-acoustic-sounds/'
52
- 'TAU-urban-acoustic-scenes-2019-evaluation')
53
- logging.info("Loading %s dataset: fg_dir=%s bg_dir=%s" %
54
- (dset, self.fg_dir, self.bg_dir))
55
-
56
- self.samples = sorted(list(
57
- Path(os.path.join(input_dir, 'jams', dset)).glob('[0-9]*')))
58
-
59
- jamsfile = os.path.join(self.samples[0], 'mixture.jams')
60
- _, jams, _, _ = scaper.generate_from_jams(
61
- jamsfile, fg_path=self.fg_dir, bg_path=self.bg_dir)
62
- _sr = jams['annotations'][0]['sandbox']['scaper']['sr']
63
- assert _sr == sr, "Sampling rate provided does not match the data"
64
-
65
- if resample_rate is not None:
66
- self.resampler = AT.Resample(sr, resample_rate)
67
- self.sr = resample_rate
68
- else:
69
- self.resampler = lambda a: a
70
- self.sr = sr
71
-
72
- def _get_label_vector(self, labels):
73
- """
74
- Generates a multi-hot vector corresponding to `labels`.
75
- """
76
- vector = torch.zeros(len(FSDSoundScapesDataset._labels))
77
-
78
- for label in labels:
79
- idx = FSDSoundScapesDataset._labels.index(label)
80
- assert vector[idx] == 0, "Repeated labels"
81
- vector[idx] = 1
82
-
83
- return vector
84
-
85
- def __len__(self):
86
- return len(self.samples)
87
-
88
- def __getitem__(self, idx):
89
- sample_path = self.samples[idx]
90
- jamsfile = os.path.join(sample_path, 'mixture.jams')
91
-
92
- mixture, jams, ann_list, event_audio_list = scaper.generate_from_jams(
93
- jamsfile, fg_path=self.fg_dir, bg_path=self.bg_dir)
94
- isolated_events = {}
95
- for e, a in zip(ann_list, event_audio_list[1:]):
96
- # 0th event is background
97
- isolated_events[e[2]] = a
98
- gt_events = list(pd.read_csv(
99
- os.path.join(sample_path, 'gt_events.csv'), sep='\t')['label'])
100
-
101
- mixture = torch.from_numpy(mixture).permute(1, 0)
102
- mixture = self.resampler(mixture.to(torch.float))
103
-
104
- if self.dset == 'train':
105
- labels = random.sample(gt_events, randrange(1,self.max_num_targets+1))
106
- elif self.dset == 'val':
107
- labels = gt_events[:idx%self.max_num_targets+1]
108
- elif self.dset == 'test':
109
- labels = gt_events[:self.max_num_targets]
110
- label_vector = self._get_label_vector(labels)
111
-
112
- gt = torch.zeros_like(
113
- torch.from_numpy(event_audio_list[1]).permute(1, 0))
114
- for l in labels:
115
- gt = gt + torch.from_numpy(isolated_events[l]).permute(1, 0)
116
- gt = self.resampler(gt.to(torch.float))
117
-
118
- return mixture, label_vector, gt #, jams
119
-
120
- def tensorboard_add_sample(writer, tag, sample, step, params):
121
- """
122
- Adds a sample of FSDSynthDataset to tensorboard.
123
- """
124
- if params['resample_rate'] is not None:
125
- sr = params['resample_rate']
126
- else:
127
- sr = params['sr']
128
- resample_rate = 16000 if sr > 16000 else sr
129
-
130
- m, l, gt, o = sample
131
- m, gt, o = (
132
- torchaudio.functional.resample(_, sr, resample_rate).cpu()
133
- for _ in (m, gt, o))
134
-
135
- def _add_audio(a, audio_tag, axis, plt_title):
136
- for i, ch in enumerate(a):
137
- axis.plot(ch, label='mic %d' % i)
138
- writer.add_audio(
139
- '%s/mic %d' % (audio_tag, i), ch.unsqueeze(0), step, resample_rate)
140
- axis.set_title(plt_title)
141
- axis.legend()
142
-
143
- for b in range(m.shape[0]):
144
- label = []
145
- for i in range(len(l[b, :])):
146
- if l[b, i] == 1:
147
- label.append(FSDSoundScapesDataset._labels[i])
148
-
149
- # Add waveforms
150
- rows = 3 # input, output, gt
151
- fig = plt.figure(figsize=(10, 2 * rows))
152
- axes = fig.subplots(rows, 1, sharex=True)
153
- _add_audio(m[b], '%s/sample_%d/0_input' % (tag, b), axes[0], "Mixed")
154
- _add_audio(o[b], '%s/sample_%d/1_output' % (tag, b), axes[1], "Output (%s)" % label)
155
- _add_audio(gt[b], '%s/sample_%d/2_gt' % (tag, b), axes[2], "GT (%s)" % label)
156
- writer.add_figure('%s/sample_%d/waveform' % (tag, b), fig, step)
157
-
158
- def tensorboard_add_metrics(writer, tag, metrics, label, step):
159
- """
160
- Add metrics to tensorboard.
161
- """
162
- vals = np.asarray(metrics['scale_invariant_signal_noise_ratio'])
163
-
164
- writer.add_histogram('%s/%s' % (tag, 'SI-SNRi'), vals, step)
165
-
166
- label_names = [FSDSoundScapesDataset._labels[torch.argmax(_)] for _ in label]
167
- for l, v in zip(label_names, vals):
168
- writer.add_histogram('%s/%s' % (tag, l), v, step)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/training/train.py DELETED
@@ -1,311 +0,0 @@
1
- """
2
- The main training script for training on synthetic data
3
- """
4
-
5
- import argparse
6
- import multiprocessing
7
- import os
8
- import logging
9
- from pathlib import Path
10
- import random
11
-
12
- import numpy as np
13
- import torch
14
- import torch.nn as nn
15
- import torch.nn.functional as F
16
- import torch.optim as optim
17
- from torch.utils.tensorboard import SummaryWriter
18
- from tqdm import tqdm # pylint: disable=unused-import
19
- from torchmetrics.functional import(
20
- scale_invariant_signal_noise_ratio as si_snr,
21
- signal_noise_ratio as snr,
22
- signal_distortion_ratio as sdr,
23
- scale_invariant_signal_distortion_ratio as si_sdr)
24
-
25
- from src.helpers import utils
26
- from src.training.eval import test_epoch
27
- from src.training.synthetic_dataset import FSDSoundScapesDataset as Dataset
28
- from src.training.synthetic_dataset import tensorboard_add_sample
29
-
30
- def train_epoch(model: nn.Module, device: torch.device,
31
- optimizer: optim.Optimizer,
32
- train_loader: torch.utils.data.dataloader.DataLoader,
33
- n_items: int, epoch: int = 0,
34
- writer: SummaryWriter = None, data_params = None) -> float:
35
-
36
- """
37
- Train a single epoch.
38
- """
39
- # Set the model to training.
40
- model.train()
41
-
42
- # Training loop
43
- losses = []
44
- metrics = {}
45
-
46
- with tqdm(total=len(train_loader), desc='Train', ncols=100) as t:
47
- for batch_idx, (mixed, label, gt) in enumerate(train_loader):
48
- mixed = mixed.to(device)
49
- label = label.to(device)
50
- gt = gt.to(device)
51
-
52
- # Reset grad
53
- optimizer.zero_grad()
54
-
55
- # Run through the model
56
- output = model(mixed, label)
57
-
58
- # Compute loss
59
- loss = network.loss(output, gt)
60
-
61
- losses.append(loss.item())
62
-
63
- # Backpropagation
64
- loss.backward()
65
-
66
- # Gradient clipping
67
- torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
68
-
69
- # Update the weights
70
- optimizer.step()
71
-
72
- metrics_batch = network.metrics(mixed.detach(), output.detach(),
73
- gt.detach())
74
- for k in metrics_batch.keys():
75
- if not k in metrics:
76
- metrics[k] = metrics_batch[k]
77
- else:
78
- metrics[k] += metrics_batch[k]
79
-
80
- if writer is not None and batch_idx == 0:
81
- tensorboard_add_sample(
82
- writer, tag='Train',
83
- sample=(mixed.detach()[:8], label.detach()[:8],
84
- gt.detach()[:8], output.detach()[:8]),
85
- step=epoch, params=data_params)
86
-
87
- # Show current loss in the progress meter
88
- t.set_postfix(loss='%.05f'%loss.item())
89
- t.update()
90
-
91
- if n_items is not None and batch_idx == n_items:
92
- break
93
-
94
- avg_metrics = {k: np.mean(metrics[k]) for k in metrics.keys()}
95
- avg_metrics['loss'] = np.mean(losses)
96
- avg_metrics_str = "Train:"
97
- for m in avg_metrics.keys():
98
- avg_metrics_str += ' %s=%.04f' % (m, avg_metrics[m])
99
- logging.info(avg_metrics_str)
100
-
101
- return avg_metrics
102
-
103
-
104
- def train(args: argparse.Namespace):
105
- """
106
- Train the network.
107
- """
108
-
109
- # Load dataset
110
- data_train = Dataset(**args.train_data)
111
- logging.info("Loaded train dataset at %s containing %d elements" %
112
- (args.train_data['input_dir'], len(data_train)))
113
- data_val = Dataset(**args.val_data)
114
- logging.info("Loaded test dataset at %s containing %d elements" %
115
- (args.val_data['input_dir'], len(data_val)))
116
-
117
- # Set up the device and workers.
118
- use_cuda = args.use_cuda and torch.cuda.is_available()
119
- if use_cuda:
120
- gpu_ids = args.gpu_ids if args.gpu_ids is not None\
121
- else range(torch.cuda.device_count())
122
- device_ids = [_ for _ in gpu_ids]
123
- data_parallel = len(device_ids) > 1
124
- device = 'cuda:%d' % device_ids[0]
125
- torch.cuda.set_device(device_ids[0])
126
- logging.info("Using CUDA devices: %s" % str(device_ids))
127
- else:
128
- data_parallel = False
129
- device = torch.device('cpu')
130
- logging.info("Using device: CPU")
131
-
132
- # Set multiprocessing params
133
- num_workers = min(multiprocessing.cpu_count(), args.n_workers)
134
- kwargs = {
135
- 'num_workers': num_workers,
136
- 'pin_memory': True
137
- } if use_cuda else {}
138
-
139
- # Set up data loaders
140
- #print(args.batch_size, args.eval_batch_size)
141
- train_loader = torch.utils.data.DataLoader(data_train,
142
- batch_size=args.batch_size,
143
- shuffle=True, **kwargs)
144
- val_loader = torch.utils.data.DataLoader(data_val,
145
- batch_size=args.eval_batch_size,
146
- **kwargs)
147
-
148
- # Set up model
149
- model = network.Net(**args.model_params)
150
-
151
- # Add graph to tensorboard with example train samples
152
- # _mixed, _label, _ = next(iter(val_loader))
153
- # args.writer.add_graph(model, (_mixed, _label))
154
-
155
- if use_cuda and data_parallel:
156
- model = nn.DataParallel(model, device_ids=device_ids)
157
- logging.info("Using data parallel model")
158
- model.to(device)
159
-
160
- # Set up the optimizer
161
- logging.info("Initializing optimizer with %s" % str(args.optim))
162
- optimizer = network.optimizer(model, **args.optim, data_parallel=data_parallel)
163
- logging.info('Learning rates initialized to:' + utils.format_lr_info(optimizer))
164
-
165
- lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
166
- optimizer, **args.lr_sched)
167
- logging.info("Initialized LR scheduler with params: fix_lr_epochs=%d %s"
168
- % (args.fix_lr_epochs, str(args.lr_sched)))
169
-
170
- base_metric = args.base_metric
171
- train_metrics = {}
172
- val_metrics = {}
173
-
174
- # Load the model if `args.start_epoch` is greater than 0. This will load the
175
- # model from epoch = `args.start_epoch - 1`
176
- assert args.start_epoch >=0, "start_epoch must be greater than 0."
177
- if args.start_epoch > 0:
178
- checkpoint_path = os.path.join(args.exp_dir,
179
- '%d.pt' % (args.start_epoch - 1))
180
- _, train_metrics, val_metrics = utils.load_checkpoint(
181
- checkpoint_path, model, optim=optimizer, lr_sched=lr_scheduler,
182
- data_parallel=data_parallel)
183
- logging.info("Loaded checkpoint from %s" % checkpoint_path)
184
- logging.info("Learning rates restored to:" + utils.format_lr_info(optimizer))
185
-
186
- # Training loop
187
- try:
188
- torch.autograd.set_detect_anomaly(args.detect_anomaly)
189
- for epoch in range(args.start_epoch, args.epochs + 1):
190
- logging.info("Epoch %d:" % epoch)
191
- checkpoint_file = os.path.join(args.exp_dir, '%d.pt' % epoch)
192
- assert not os.path.exists(checkpoint_file), \
193
- "Checkpoint file %s already exists" % checkpoint_file
194
- #print("---- begin trianivg")
195
- curr_train_metrics = train_epoch(model, device, optimizer,
196
- train_loader, args.n_train_items,
197
- epoch=epoch, writer=args.writer,
198
- data_params=args.train_data)
199
- #raise KeyboardInterrupt
200
- curr_test_metrics = test_epoch(model, device, val_loader,
201
- args.n_test_items, network.loss,
202
- network.metrics, epoch=epoch,
203
- writer=args.writer,
204
- data_params=args.val_data)
205
- # LR scheduler
206
- if epoch >= args.fix_lr_epochs:
207
- lr_scheduler.step(curr_test_metrics[base_metric])
208
- logging.info(
209
- "LR after scheduling step: %s" %
210
- [_['lr'] for _ in optimizer.param_groups])
211
-
212
- # Write metrics to tensorboard
213
- args.writer.add_scalars('Train', curr_train_metrics, epoch)
214
- args.writer.add_scalars('Val', curr_test_metrics, epoch)
215
- args.writer.flush()
216
-
217
- for k in curr_train_metrics.keys():
218
- if not k in train_metrics:
219
- train_metrics[k] = [curr_train_metrics[k]]
220
- else:
221
- train_metrics[k].append(curr_train_metrics[k])
222
-
223
- for k in curr_test_metrics.keys():
224
- if not k in val_metrics:
225
- val_metrics[k] = [curr_test_metrics[k]]
226
- else:
227
- val_metrics[k].append(curr_test_metrics[k])
228
-
229
- if max(val_metrics[base_metric]) == val_metrics[base_metric][-1]:
230
- logging.info("Found best validation %s!" % base_metric)
231
-
232
- utils.save_checkpoint(
233
- checkpoint_file, epoch, model, optimizer, lr_scheduler,
234
- train_metrics, val_metrics, data_parallel)
235
- logging.info("Saved checkpoint at %s" % checkpoint_file)
236
-
237
- utils.save_graph(train_metrics, val_metrics, args.exp_dir)
238
-
239
- return train_metrics, val_metrics
240
-
241
-
242
- except KeyboardInterrupt:
243
- print("Interrupted")
244
- except Exception as _: # pylint: disable=broad-except
245
- import traceback # pylint: disable=import-outside-toplevel
246
- traceback.print_exc()
247
-
248
-
249
- if __name__ == '__main__':
250
- parser = argparse.ArgumentParser()
251
- # Data Params
252
- parser.add_argument('exp_dir', type=str,
253
- default='./experiments/fsd_mask_label_mult',
254
- help="Path to save checkpoints and logs.")
255
-
256
- parser.add_argument('--n_train_items', type=int, default=None,
257
- help="Number of items to train on in each epoch")
258
- parser.add_argument('--n_test_items', type=int, default=None,
259
- help="Number of items to test.")
260
- parser.add_argument('--start_epoch', type=int, default=0,
261
- help="Start epoch")
262
- parser.add_argument('--pretrain_path', type=str,
263
- help="Path to pretrained weights")
264
- parser.add_argument('--use_cuda', dest='use_cuda', action='store_true',
265
- help="Whether to use cuda")
266
- parser.add_argument('--gpu_ids', nargs='+', type=int, default=None,
267
- help="List of GPU ids used for training. "
268
- "Eg., --gpu_ids 2 4. All GPUs are used by default.")
269
- parser.add_argument('--detect_anomaly', dest='detect_anomaly',
270
- action='store_true',
271
- help="Whether to use cuda")
272
- parser.add_argument('--wandb', dest='wandb', action='store_true',
273
- help="Whether to sync tensorboard to wandb")
274
-
275
- args = parser.parse_args()
276
-
277
- # Set the random seed for reproducible experiments
278
- torch.manual_seed(230)
279
- random.seed(230)
280
- np.random.seed(230)
281
- if args.use_cuda:
282
- torch.cuda.manual_seed(230)
283
-
284
- # Set up checkpoints
285
- if not os.path.exists(args.exp_dir):
286
- os.makedirs(args.exp_dir)
287
-
288
- utils.set_logger(os.path.join(args.exp_dir, 'train.log'))
289
-
290
- # Load model and training params
291
- params = utils.Params(os.path.join(args.exp_dir, 'config.json'))
292
- for k, v in params.__dict__.items():
293
- vars(args)[k] = v
294
-
295
- # Initialize tensorboard writer
296
- tensorboard_dir = os.path.join(args.exp_dir, 'tensorboard')
297
- args.writer = SummaryWriter(tensorboard_dir, purge_step=args.start_epoch)
298
- if args.wandb:
299
- import wandb
300
- wandb.init(
301
- project='Semaudio', sync_tensorboard=True,
302
- dir=tensorboard_dir, name=os.path.basename(args.exp_dir))
303
-
304
- exec("import %s as network" % args.model)
305
- logging.info("Imported the model from '%s'." % args.model)
306
-
307
- train(args)
308
-
309
- args.writer.close()
310
- if args.wandb:
311
- wandb.finish()