victan committed on
Commit
3ce48af
1 Parent(s): cc83e80

Upload train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train.py +302 -0
train.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from datetime import datetime
3
+ import json
4
+ import logging
5
+ import os
6
+ import random
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.data
12
+
13
+ from lib import dataset
14
+ from lib import nets
15
+ from lib import spec_utils
16
+
17
+
18
def setup_logger(name, logfile='LOGFILENAME.log'):
    """Configure and return a named logger.

    DEBUG and above go to *logfile*; INFO and above go to stderr via a
    plain StreamHandler. Propagation to the root logger is disabled so
    records are not emitted twice.

    Fix: ``logging.getLogger`` returns the same object for the same name,
    so repeated calls used to stack duplicate handlers and duplicate every
    record. If the logger already has handlers, it is returned unchanged.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    logger.propagate = False  # keep records out of the root logger

    if logger.handlers:
        # Already configured in this process; adding another file/stream
        # handler pair would double every log line.
        return logger

    fh = logging.FileHandler(logfile, encoding='utf8')
    fh.setLevel(logging.DEBUG)
    fh_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    fh.setFormatter(fh_formatter)

    sh = logging.StreamHandler()
    sh.setLevel(logging.INFO)

    logger.addHandler(fh)
    logger.addHandler(sh)

    return logger
35
+
36
+
37
def to_wave(spec, n_fft, hop_length, window):
    """Invert a batch of stereo spectrograms back to time-domain waveforms.

    *spec* is (B, 2, N, T): the two channels are folded into the batch axis
    so a single ``torch.istft`` call handles everything, then unfolded back
    into (B, 2, samples).
    """
    batch_size, _, n_bins, n_frames = spec.shape
    flat_spec = spec.reshape(-1, n_bins, n_frames)
    flat_wave = torch.istft(flat_spec, n_fft, hop_length, window=window)
    return flat_wave.reshape(batch_size, 2, -1)
44
+
45
+
46
def sdr_loss(y, y_pred, eps=1e-8):
    """Negative normalized correlation (SDR-style) between target and estimate.

    Returns a value in roughly [-1, 1]; -1 means the estimate is perfectly
    aligned with the target, so minimizing this maximizes similarity.
    *eps* guards the denominator against zero-norm inputs.
    """
    numerator = torch.sum(y * y_pred)
    denominator = torch.linalg.norm(y) * torch.linalg.norm(y_pred) + eps
    return -(numerator / denominator)
51
+
52
+
53
def weighted_sdr_loss(y, y_pred, n, n_pred, eps=1e-8):
    """Energy-weighted combination of target and noise SDR terms, negated.

    The weight ``alpha`` is the target's share of the total signal energy,
    so samples dominated by the target emphasize the target SDR and
    noise-heavy samples emphasize the noise SDR.
    """
    def _normalized_corr(a, b):
        # same normalized-dot-product form as sdr_loss, without negation
        return torch.sum(a * b) / (torch.linalg.norm(a) * torch.linalg.norm(b) + eps)

    target_energy = torch.sum(y ** 2)
    noise_energy = torch.sum(n ** 2)
    alpha = target_energy / (target_energy + noise_energy + eps)

    combined = alpha * _normalized_corr(y, y_pred)
    combined = combined + (1 - alpha) * _normalized_corr(n, n_pred)

    return -combined
66
+
67
+
68
def train_epoch(dataloader, model, device, optimizer, accumulation_steps):
    """Train *model* for one epoch with gradient accumulation.

    The loss per batch is L1 on spectrogram magnitudes plus a small
    (0.01-weighted) negative-SDR term on the reconstructed waveforms.
    Optimizer steps are taken every *accumulation_steps* batches, with a
    final flush for a trailing partial window.

    Returns the mean per-sample loss over ``dataloader.dataset``.
    """
    model.train()
    n_fft = model.n_fft
    hop_length = model.hop_length
    window = torch.hann_window(n_fft).to(device)

    sum_loss = 0
    crit_l1 = nn.L1Loss()
    # Fix: without this initializer, an empty dataloader left `itr`
    # unbound and the flush check below raised NameError.
    itr = -1

    for itr, (X_batch, y_batch) in enumerate(dataloader):
        X_batch = X_batch.to(device)
        y_batch = y_batch.to(device)

        mask = model(X_batch)

        y_pred = X_batch * mask
        y_wave_batch = to_wave(y_batch, n_fft, hop_length, window)
        y_wave_pred = to_wave(y_pred, n_fft, hop_length, window)

        loss = crit_l1(torch.abs(y_batch), torch.abs(y_pred))
        loss += sdr_loss(y_wave_batch, y_wave_pred) * 0.01

        # Scale so the accumulated gradient matches one large-batch step.
        accum_loss = loss / accumulation_steps
        accum_loss.backward()

        if (itr + 1) % accumulation_steps == 0:
            optimizer.step()
            model.zero_grad()

        sum_loss += loss.item() * len(X_batch)

    # Flush gradients left over from a trailing, partial accumulation window
    # (no-op when the loop never ran: (−1 + 1) % steps == 0).
    if (itr + 1) % accumulation_steps != 0:
        optimizer.step()
        model.zero_grad()

    return sum_loss / len(dataloader.dataset)
105
+
106
+
107
def validate_epoch(dataloader, model, device):
    """Evaluate *model* on the validation dataloader.

    Uses the same loss as training (magnitude L1 + 0.01 * negative SDR on
    reconstructed waveforms) but runs under ``torch.no_grad`` with
    ``model.predict``. Returns the mean per-sample loss.
    """
    model.eval()
    n_fft, hop_length = model.n_fft, model.hop_length
    window = torch.hann_window(n_fft).to(device)

    l1 = nn.L1Loss()
    total = 0

    with torch.no_grad():
        for X_batch, y_batch in dataloader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)

            y_pred = model.predict(X_batch)

            # Crop the target to the prediction's size — predict()
            # presumably trims border frames; confirm in spec_utils.
            y_batch = spec_utils.crop_center(y_batch, y_pred)
            wave_true = to_wave(y_batch, n_fft, hop_length, window)
            wave_pred = to_wave(y_pred, n_fft, hop_length, window)

            batch_loss = l1(torch.abs(y_batch), torch.abs(y_pred))
            batch_loss += sdr_loss(wave_true, wave_pred) * 0.01

            total += batch_loss.item() * len(X_batch)

    return total / len(dataloader.dataset)
133
+
134
+
135
def main():
    """Entry point: parse CLI args, build datasets, then train/validate.

    Relies on the module-level globals ``logger`` and ``timestamp`` being
    set by the ``__main__`` guard before this is called. Side effects:
    writes ``val_{timestamp}.json`` (for random splits),
    ``loss_{timestamp}.json``, and checkpoints under ``models/``.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--gpu', '-g', type=int, default=-1)
    p.add_argument('--seed', '-s', type=int, default=2019)
    p.add_argument('--sr', '-r', type=int, default=44100)
    p.add_argument('--hop_length', '-H', type=int, default=1024)
    p.add_argument('--n_fft', '-f', type=int, default=2048)
    p.add_argument('--dataset', '-d', required=True)
    p.add_argument('--split_mode', '-S', type=str, choices=['random', 'subdirs'], default='random')
    p.add_argument('--learning_rate', '-l', type=float, default=0.001)
    p.add_argument('--lr_min', type=float, default=0.0001)
    p.add_argument('--lr_decay_factor', type=float, default=0.9)
    p.add_argument('--lr_decay_patience', type=int, default=6)
    p.add_argument('--batchsize', '-B', type=int, default=4)
    p.add_argument('--accumulation_steps', '-A', type=int, default=1)
    p.add_argument('--cropsize', '-C', type=int, default=256)
    p.add_argument('--patches', '-p', type=int, default=16)
    p.add_argument('--val_rate', '-v', type=float, default=0.2)
    p.add_argument('--val_filelist', '-V', type=str, default=None)
    p.add_argument('--val_batchsize', '-b', type=int, default=4)
    p.add_argument('--val_cropsize', '-c', type=int, default=256)
    p.add_argument('--num_workers', '-w', type=int, default=4)
    p.add_argument('--epoch', '-E', type=int, default=200)
    p.add_argument('--reduction_rate', '-R', type=float, default=0.0)
    p.add_argument('--reduction_level', '-L', type=float, default=0.2)
    p.add_argument('--mixup_rate', '-M', type=float, default=0.0)
    p.add_argument('--mixup_alpha', '-a', type=float, default=1.0)
    p.add_argument('--pretrained_model', '-P', type=str, default=None)
    p.add_argument('--debug', action='store_true')
    args = p.parse_args()

    logger.debug(vars(args))

    # Seed every RNG in play for reproducibility.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    val_filelist = []
    if args.val_filelist is not None:
        with open(args.val_filelist, 'r', encoding='utf8') as f:
            val_filelist = json.load(f)

    train_filelist, val_filelist = dataset.train_val_split(
        dataset_dir=args.dataset,
        split_mode=args.split_mode,
        val_rate=args.val_rate,
        val_filelist=val_filelist
    )

    if args.debug:
        logger.info('### DEBUG MODE')
        train_filelist = train_filelist[:1]
        val_filelist = val_filelist[:1]
    elif args.val_filelist is None and args.split_mode == 'random':
        # Persist the random split so this run can be reproduced later.
        with open('val_{}.json'.format(timestamp), 'w', encoding='utf8') as f:
            json.dump(val_filelist, f, ensure_ascii=False)

    for i, (X_fname, y_fname) in enumerate(val_filelist):
        logger.info('{} {} {}'.format(i + 1, os.path.basename(X_fname), os.path.basename(y_fname)))

    # Per-frequency-bin weight used by the (optional) reduction augmentation:
    # ramps 0->1 over the lowest ~200 Hz, 1->0 up to 22050 Hz, zero above.
    bins = args.n_fft // 2 + 1
    freq_to_bin = 2 * bins / args.sr
    unstable_bins = int(200 * freq_to_bin)
    stable_bins = int(22050 * freq_to_bin)
    reduction_weight = np.concatenate([
        np.linspace(0, 1, unstable_bins, dtype=np.float32)[:, None],
        np.linspace(1, 0, stable_bins - unstable_bins, dtype=np.float32)[:, None],
        np.zeros((bins - stable_bins, 1), dtype=np.float32),
    ], axis=0) * args.reduction_level

    device = torch.device('cpu')
    model = nets.CascadedNet(args.n_fft, args.hop_length, 32, 128, True)
    if args.pretrained_model is not None:
        # Load on CPU first; moved to the target device below.
        model.load_state_dict(torch.load(args.pretrained_model, map_location=device))
    if torch.cuda.is_available() and args.gpu >= 0:
        device = torch.device('cuda:{}'.format(args.gpu))
    model.to(device)

    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=args.learning_rate
    )

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        factor=args.lr_decay_factor,
        patience=args.lr_decay_patience,
        threshold=1e-6,
        min_lr=args.lr_min,
        verbose=True
    )

    training_set = dataset.make_training_set(
        filelist=train_filelist,
        sr=args.sr,
        hop_length=args.hop_length,
        n_fft=args.n_fft
    )

    train_dataset = dataset.VocalRemoverTrainingSet(
        training_set * args.patches,
        cropsize=args.cropsize,
        reduction_rate=args.reduction_rate,
        reduction_weight=reduction_weight,
        mixup_rate=args.mixup_rate,
        mixup_alpha=args.mixup_alpha
    )

    train_dataloader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=args.batchsize,
        shuffle=True,
        num_workers=args.num_workers
    )

    patch_list = dataset.make_validation_set(
        filelist=val_filelist,
        cropsize=args.val_cropsize,
        sr=args.sr,
        hop_length=args.hop_length,
        n_fft=args.n_fft,
        offset=model.offset
    )

    val_dataset = dataset.VocalRemoverValidationSet(
        patch_list=patch_list
    )

    val_dataloader = torch.utils.data.DataLoader(
        dataset=val_dataset,
        batch_size=args.val_batchsize,
        shuffle=False,
        num_workers=args.num_workers
    )

    # Fix: checkpoints are saved under models/, which torch.save does not
    # create; on a fresh checkout the first save raised FileNotFoundError.
    os.makedirs('models', exist_ok=True)

    log = []
    best_loss = np.inf
    for epoch in range(args.epoch):
        logger.info('# epoch {}'.format(epoch))
        train_loss = train_epoch(train_dataloader, model, device, optimizer, args.accumulation_steps)
        val_loss = validate_epoch(val_dataloader, model, device)

        logger.info(
            ' * training loss = {:.6f}, validation loss = {:.6f}'
            .format(train_loss, val_loss)
        )

        scheduler.step(val_loss)

        if val_loss < best_loss:
            best_loss = val_loss
            logger.info(' * best validation loss')
            model_path = 'models/model_iter{}.pth'.format(epoch)
            torch.save(model.state_dict(), model_path)

        log.append([train_loss, val_loss])
        with open('loss_{}.json'.format(timestamp), 'w', encoding='utf8') as f:
            json.dump(log, f, ensure_ascii=False)
293
+
294
+
295
# Script entry point: a per-run timestamp names the log/loss/val artifacts,
# and the module-level `logger` created here is the one main() uses.
if __name__ == '__main__':
    timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
    logger = setup_logger(__name__, 'train_{}.log'.format(timestamp))

    try:
        main()
    except Exception as e:
        # Top-level boundary: record the full traceback instead of
        # letting the process die with an unlogged stack trace.
        logger.exception(e)