Upload train.py with huggingface_hub
Browse files
train.py
ADDED
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from datetime import datetime
|
3 |
+
import json
|
4 |
+
import logging
|
5 |
+
import os
|
6 |
+
import random
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.utils.data
|
12 |
+
|
13 |
+
from lib import dataset
|
14 |
+
from lib import nets
|
15 |
+
from lib import spec_utils
|
16 |
+
|
17 |
+
|
18 |
+
def setup_logger(name, logfile='LOGFILENAME.log'):
    """Create a logger that writes DEBUG records to *logfile* and INFO to the console.

    Args:
        name: logger name passed to ``logging.getLogger``.
        logfile: path of the log file (opened with utf8 encoding).

    Returns:
        The configured ``logging.Logger`` instance.
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    # Keep records from bubbling up to the root logger's handlers.
    logger.propagate = False

    file_handler = logging.FileHandler(logfile, encoding='utf8')
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(
        logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    )

    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.INFO)

    for handler in (file_handler, stream_handler):
        logger.addHandler(handler)

    return logger
|
35 |
+
|
36 |
+
|
37 |
+
def to_wave(spec, n_fft, hop_length, window):
    """Invert a batch of stereo spectrograms back to waveforms.

    Args:
        spec: spectrogram tensor of shape (B, 2, N, T) — 2 audio channels.
        n_fft: FFT size used to produce ``spec``.
        hop_length: STFT hop length used to produce ``spec``.
        window: analysis window tensor (same device as ``spec``).

    Returns:
        Waveform tensor of shape (B, 2, samples).
    """
    batch_size, _, n_bins, n_frames = spec.shape
    # istft works on a flat batch of mono spectrograms, so fold the
    # channel axis into the batch axis and unfold it again afterwards.
    flat = spec.reshape(-1, n_bins, n_frames)
    wave = torch.istft(flat, n_fft, hop_length, window=window)
    return wave.reshape(batch_size, 2, -1)
|
44 |
+
|
45 |
+
|
46 |
+
def sdr_loss(y, y_pred, eps=1e-8):
    """Negative normalized correlation between target and estimate.

    Computes the inner product of ``y`` and ``y_pred`` divided by the
    product of their norms (cosine similarity), negated so that minimizing
    it drives the estimate toward the target.

    Args:
        y: target waveform tensor.
        y_pred: estimated waveform tensor (same shape as ``y``).
        eps: small constant guarding against division by zero.

    Returns:
        Scalar tensor; -1 at perfect alignment.
    """
    numerator = (y * y_pred).sum()
    denominator = torch.linalg.norm(y) * torch.linalg.norm(y_pred) + eps
    return -(numerator / denominator)
|
51 |
+
|
52 |
+
|
53 |
+
def weighted_sdr_loss(y, y_pred, n, n_pred, eps=1e-8):
    """Negative energy-weighted SDR over the target and noise estimates.

    The target term and the noise term are each a normalized correlation
    (as in ``sdr_loss``), blended by the fraction of total energy that
    belongs to the target signal.

    Args:
        y: target waveform; y_pred: its estimate.
        n: noise waveform; n_pred: its estimate.
        eps: small constant guarding against division by zero.

    Returns:
        Scalar tensor loss (more negative is better).
    """
    def _similarity(a, b):
        # Normalized inner product of two signals.
        return (a * b).sum() / (torch.linalg.norm(a) * torch.linalg.norm(b) + eps)

    target_sim = _similarity(y, y_pred)
    noise_sim = _similarity(n, n_pred)

    # Weight each term by its share of the total signal energy.
    target_energy = torch.sum(y ** 2)
    weight = target_energy / (target_energy + torch.sum(n ** 2) + eps)

    return -(weight * target_sim + (1 - weight) * noise_sim)
|
66 |
+
|
67 |
+
|
68 |
+
def train_epoch(dataloader, model, device, optimizer, accumulation_steps):
    """Run one training epoch with gradient accumulation.

    Loss per batch is L1 on spectrogram magnitudes plus a small (0.01x)
    time-domain SDR term computed on the inverse-STFT waveforms.

    Args:
        dataloader: yields (mixture, target) spectrogram batches.
        model: network exposing ``n_fft`` / ``hop_length`` attributes and
            returning a mask when called on a mixture batch.
        device: torch device to run on.
        optimizer: optimizer stepped every ``accumulation_steps`` batches.
        accumulation_steps: number of batches to accumulate gradients over.

    Returns:
        Mean per-sample training loss over the whole dataset.
    """
    model.train()
    n_fft = model.n_fft
    hop_length = model.hop_length
    window = torch.hann_window(n_fft).to(device)

    l1_criterion = nn.L1Loss()
    total_loss = 0

    for step, (x_spec, y_spec) in enumerate(dataloader):
        x_spec = x_spec.to(device)
        y_spec = y_spec.to(device)

        # Predicted spectrogram = mixture scaled by the estimated mask.
        pred = x_spec * model(x_spec)

        target_wave = to_wave(y_spec, n_fft, hop_length, window)
        pred_wave = to_wave(pred, n_fft, hop_length, window)

        loss = l1_criterion(torch.abs(y_spec), torch.abs(pred))
        loss += sdr_loss(target_wave, pred_wave) * 0.01

        # Scale so accumulated gradients match a single large-batch step.
        (loss / accumulation_steps).backward()

        if (step + 1) % accumulation_steps == 0:
            optimizer.step()
            model.zero_grad()

        total_loss += loss.item() * len(x_spec)

    # Flush gradients left over from a trailing partial accumulation window.
    if (step + 1) % accumulation_steps != 0:
        optimizer.step()
        model.zero_grad()

    return total_loss / len(dataloader.dataset)
|
105 |
+
|
106 |
+
|
107 |
+
def validate_epoch(dataloader, model, device):
    """Evaluate the model for one pass over the validation set.

    Uses the same loss as training (magnitude L1 + 0.01x waveform SDR),
    but goes through ``model.predict`` and crops the target to match the
    prediction's trimmed borders.

    Args:
        dataloader: yields (mixture, target) spectrogram batches.
        model: network exposing ``n_fft`` / ``hop_length`` and ``predict``.
        device: torch device to run on.

    Returns:
        Mean per-sample validation loss over the whole dataset.
    """
    model.eval()
    n_fft = model.n_fft
    hop_length = model.hop_length
    window = torch.hann_window(n_fft).to(device)

    l1_criterion = nn.L1Loss()
    total_loss = 0

    with torch.no_grad():
        for x_spec, y_spec in dataloader:
            x_spec = x_spec.to(device)
            y_spec = y_spec.to(device)

            pred = model.predict(x_spec)

            # predict() trims the spectrogram borders, so crop the target
            # to the same region before comparing.
            y_spec = spec_utils.crop_center(y_spec, pred)
            target_wave = to_wave(y_spec, n_fft, hop_length, window)
            pred_wave = to_wave(pred, n_fft, hop_length, window)

            loss = l1_criterion(torch.abs(y_spec), torch.abs(pred))
            loss += sdr_loss(target_wave, pred_wave) * 0.01

            total_loss += loss.item() * len(x_spec)

    return total_loss / len(dataloader.dataset)
|
133 |
+
|
134 |
+
|
135 |
+
def main():
    """Parse CLI arguments and run the full training loop.

    Reads the module-level globals ``logger`` and ``timestamp`` created in
    the ``__main__`` guard, so this is only callable as a script entry point.

    Side effects: writes ``val_<timestamp>.json`` (random split without an
    explicit validation filelist), ``loss_<timestamp>.json`` after every
    epoch, and best-validation checkpoints under ``models/``.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--gpu', '-g', type=int, default=-1)
    p.add_argument('--seed', '-s', type=int, default=2019)
    p.add_argument('--sr', '-r', type=int, default=44100)
    p.add_argument('--hop_length', '-H', type=int, default=1024)
    p.add_argument('--n_fft', '-f', type=int, default=2048)
    p.add_argument('--dataset', '-d', required=True)
    p.add_argument('--split_mode', '-S', type=str, choices=['random', 'subdirs'], default='random')
    p.add_argument('--learning_rate', '-l', type=float, default=0.001)
    p.add_argument('--lr_min', type=float, default=0.0001)
    p.add_argument('--lr_decay_factor', type=float, default=0.9)
    p.add_argument('--lr_decay_patience', type=int, default=6)
    p.add_argument('--batchsize', '-B', type=int, default=4)
    p.add_argument('--accumulation_steps', '-A', type=int, default=1)
    p.add_argument('--cropsize', '-C', type=int, default=256)
    p.add_argument('--patches', '-p', type=int, default=16)
    p.add_argument('--val_rate', '-v', type=float, default=0.2)
    p.add_argument('--val_filelist', '-V', type=str, default=None)
    p.add_argument('--val_batchsize', '-b', type=int, default=4)
    p.add_argument('--val_cropsize', '-c', type=int, default=256)
    p.add_argument('--num_workers', '-w', type=int, default=4)
    p.add_argument('--epoch', '-E', type=int, default=200)
    p.add_argument('--reduction_rate', '-R', type=float, default=0.0)
    p.add_argument('--reduction_level', '-L', type=float, default=0.2)
    p.add_argument('--mixup_rate', '-M', type=float, default=0.0)
    p.add_argument('--mixup_alpha', '-a', type=float, default=1.0)
    p.add_argument('--pretrained_model', '-P', type=str, default=None)
    p.add_argument('--debug', action='store_true')
    args = p.parse_args()

    logger.debug(vars(args))

    # Seed every RNG involved so the split and augmentation are reproducible.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    val_filelist = []
    if args.val_filelist is not None:
        with open(args.val_filelist, 'r', encoding='utf8') as f:
            val_filelist = json.load(f)

    train_filelist, val_filelist = dataset.train_val_split(
        dataset_dir=args.dataset,
        split_mode=args.split_mode,
        val_rate=args.val_rate,
        val_filelist=val_filelist
    )

    if args.debug:
        logger.info('### DEBUG MODE')
        train_filelist = train_filelist[:1]
        val_filelist = val_filelist[:1]
    elif args.val_filelist is None and args.split_mode == 'random':
        # Persist the random split so the run can be reproduced later.
        with open('val_{}.json'.format(timestamp), 'w', encoding='utf8') as f:
            json.dump(val_filelist, f, ensure_ascii=False)

    for i, (X_fname, y_fname) in enumerate(val_filelist):
        logger.info('{} {} {}'.format(i + 1, os.path.basename(X_fname), os.path.basename(y_fname)))

    # Per-frequency-bin weight for the reduction augmentation: ramps 0->1 up
    # to ~200 Hz, then 1->0 up to 22050 Hz, and is zero above that.
    bins = args.n_fft // 2 + 1
    freq_to_bin = 2 * bins / args.sr
    unstable_bins = int(200 * freq_to_bin)
    stable_bins = int(22050 * freq_to_bin)
    reduction_weight = np.concatenate([
        np.linspace(0, 1, unstable_bins, dtype=np.float32)[:, None],
        np.linspace(1, 0, stable_bins - unstable_bins, dtype=np.float32)[:, None],
        np.zeros((bins - stable_bins, 1), dtype=np.float32),
    ], axis=0) * args.reduction_level

    # Build on CPU first so a pretrained checkpoint loads regardless of the
    # device it was saved from, then move the model to the target device.
    device = torch.device('cpu')
    model = nets.CascadedNet(args.n_fft, args.hop_length, 32, 128, True)
    if args.pretrained_model is not None:
        model.load_state_dict(torch.load(args.pretrained_model, map_location=device))
    if torch.cuda.is_available() and args.gpu >= 0:
        device = torch.device('cuda:{}'.format(args.gpu))
    model.to(device)

    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=args.learning_rate
    )

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        factor=args.lr_decay_factor,
        patience=args.lr_decay_patience,
        threshold=1e-6,
        min_lr=args.lr_min,
        verbose=True
    )

    training_set = dataset.make_training_set(
        filelist=train_filelist,
        sr=args.sr,
        hop_length=args.hop_length,
        n_fft=args.n_fft
    )

    train_dataset = dataset.VocalRemoverTrainingSet(
        training_set * args.patches,
        cropsize=args.cropsize,
        reduction_rate=args.reduction_rate,
        reduction_weight=reduction_weight,
        mixup_rate=args.mixup_rate,
        mixup_alpha=args.mixup_alpha
    )

    train_dataloader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=args.batchsize,
        shuffle=True,
        num_workers=args.num_workers
    )

    patch_list = dataset.make_validation_set(
        filelist=val_filelist,
        cropsize=args.val_cropsize,
        sr=args.sr,
        hop_length=args.hop_length,
        n_fft=args.n_fft,
        offset=model.offset
    )

    val_dataset = dataset.VocalRemoverValidationSet(
        patch_list=patch_list
    )

    val_dataloader = torch.utils.data.DataLoader(
        dataset=val_dataset,
        batch_size=args.val_batchsize,
        shuffle=False,
        num_workers=args.num_workers
    )

    # torch.save does not create intermediate directories; ensure the
    # checkpoint directory exists before training starts.
    os.makedirs('models', exist_ok=True)

    log = []
    best_loss = np.inf
    for epoch in range(args.epoch):
        logger.info('# epoch {}'.format(epoch))
        train_loss = train_epoch(train_dataloader, model, device, optimizer, args.accumulation_steps)
        val_loss = validate_epoch(val_dataloader, model, device)

        logger.info(
            '  * training loss = {:.6f}, validation loss = {:.6f}'
            .format(train_loss, val_loss)
        )

        scheduler.step(val_loss)

        if val_loss < best_loss:
            best_loss = val_loss
            logger.info('  * best validation loss')
            model_path = 'models/model_iter{}.pth'.format(epoch)
            torch.save(model.state_dict(), model_path)

        # Dump the running loss history every epoch so a crash loses nothing.
        log.append([train_loss, val_loss])
        with open('loss_{}.json'.format(timestamp), 'w', encoding='utf8') as f:
            json.dump(log, f, ensure_ascii=False)
|
293 |
+
|
294 |
+
|
295 |
+
# Script entry point. NOTE: `timestamp` and `logger` are module-level
# globals that main() reads, so they must be created before main() runs.
if __name__ == '__main__':
    timestamp = datetime.now().strftime('%Y%m%d%H%M%S')
    logger = setup_logger(__name__, 'train_{}.log'.format(timestamp))

    try:
        main()
    except Exception as e:
        # Top-level boundary: record the full traceback in the log file
        # instead of letting the process die with an unlogged stack trace.
        logger.exception(e)
|