Commit · d8d3e30
Parent(s): 4a7a6b8

Add gradient clipping and lr scheduler
Files changed:
- cfg/config.yaml   +5 -1
- remfx/datasets.py +4 -7
- remfx/models.py   +22 -0
- remfx/utils.py    +3 -4
- scripts/train.py  +1 -0
cfg/config.yaml
CHANGED
@@ -8,7 +8,6 @@ train: True
 sample_rate: 48000
 chunk_size: 262144 # 5.5s
 logs_dir: "./logs"
-log_every_n_steps: 1000
 render_files: True
 render_root: "./data/processed"
 
@@ -22,6 +21,9 @@ callbacks:
     verbose: False
     dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}
     filename: '{epoch:02d}-{valid_loss:.3f}'
+  learning_rate_monitor:
+    _target_: pytorch_lightning.callbacks.LearningRateMonitor
+    logging_interval: "step"
 
 datamodule:
   _target_: remfx.datasets.VocalSetDatamodule
@@ -77,3 +79,5 @@ trainer:
   accumulate_grad_batches: 1
   accelerator: null
   devices: 1
+  gradient_clip_val: 10.0
+  max_steps: 50000
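The two new trainer keys map one-to-one onto pytorch_lightning.Trainer arguments, and the new callback is the stock Lightning LearningRateMonitor. A minimal sketch of the equivalent Python-side setup, assuming the usual Hydra _target_ instantiation (values copied from the config above, everything else illustrative):

import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor

# callbacks.learning_rate_monitor from the config: log the LR every step
lr_monitor = LearningRateMonitor(logging_interval="step")

# trainer section: clip gradient norms at 10.0 and stop after 50000 optimizer steps
trainer = pl.Trainer(
    accumulate_grad_batches=1,
    devices=1,
    gradient_clip_val=10.0,
    max_steps=50000,
    callbacks=[lr_monitor],
)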
remfx/datasets.py
CHANGED
@@ -17,7 +17,7 @@ class VocalSet(Dataset):
         self,
         root: str,
         sample_rate: int,
-        chunk_size_in_sec: int = 3,
+        chunk_size: int = 3,
         effect_types: List[torch.nn.Module] = None,
         render_files: bool = True,
         render_root: str = None,
@@ -28,7 +28,7 @@ class VocalSet(Dataset):
         self.song_idx = []
         self.root = Path(root)
         self.render_root = Path(render_root)
-        self.chunk_size_in_sec = chunk_size_in_sec
+        self.chunk_size = chunk_size
         self.sample_rate = sample_rate
         self.mode = mode
 
@@ -46,15 +46,12 @@ class VocalSet(Dataset):
         # Split audio file into chunks, resample, then apply random effects
         self.processed_root.mkdir(parents=True, exist_ok=True)
         for audio_file in tqdm(self.files, total=len(self.files)):
-            chunks, orig_sr = create_sequential_chunks(
-                audio_file, self.chunk_size_in_sec
-            )
+            chunks, orig_sr = create_sequential_chunks(audio_file, self.chunk_size)
             for chunk in chunks:
                 resampled_chunk = torchaudio.functional.resample(
                     chunk, orig_sr, sample_rate
                 )
-
-                if resampled_chunk.shape[-1] < chunk_size_in_samples:
+                if resampled_chunk.shape[-1] < chunk_size:
                     # Skip if chunk is too small
                     continue
                 # Apply effect
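With this change the dataset takes a single chunk_size that is passed straight to create_sequential_chunks and compared against resampled_chunk.shape[-1], i.e. it is consumed as a sample count. A small arithrithmetic-free sanity check against the values in cfg/config.yaml (illustrative only, not part of the diff):

# chunk_size from cfg/config.yaml, interpreted in samples
sample_rate = 48000
chunk_size = 262144
print(f"{chunk_size / sample_rate:.2f} s per chunk")  # ~5.46 s, the "5.5s" noted in the config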
remfx/models.py
CHANGED
@@ -55,6 +55,28 @@ class RemFXModel(pl.LightningModule):
         )
         return optimizer
 
+    # Add step-based learning rate scheduler
+    def optimizer_step(
+        self,
+        epoch,
+        batch_idx,
+        optimizer,
+        optimizer_idx,
+        optimizer_closure,
+        on_tpu=False,
+        using_lbfgs=False,
+    ):
+        # update params
+        optimizer.step(closure=optimizer_closure)
+
+        # update learning rate. Reduce by factor of 10 at 80% and 95% of training
+        if self.trainer.global_step == 0.8 * self.trainer.max_steps:
+            for pg in optimizer.param_groups:
+                pg["lr"] = 0.1 * pg["lr"]
+        if self.trainer.global_step == 0.95 * self.trainer.max_steps:
+            for pg in optimizer.param_groups:
+                pg["lr"] = 0.1 * pg["lr"]
+
     def training_step(self, batch, batch_idx):
         loss = self.common_step(batch, batch_idx, mode="train")
         return loss
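The override multiplies the learning rate by 0.1 when the global step hits 80% and 95% of trainer.max_steps (steps 40000 and 47500 with the max_steps: 50000 set in the config). A sketch of the same decay expressed with a stock scheduler, using a stand-in module and learning rate rather than the repo's actual optimizer setup:

import torch

model = torch.nn.Linear(8, 8)                      # stand-in module, not RemFXModel
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
max_steps = 50000
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[int(0.8 * max_steps), int(0.95 * max_steps)],  # 40000 and 47500
    gamma=0.1,                                     # multiply the LR by 0.1 at each milestone
)
# scheduler.step() would then be called once after every optimizer.step()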
remfx/utils.py
CHANGED
@@ -132,10 +132,9 @@ def create_sequential_chunks(
     """
     chunks = []
     audio, sr = torchaudio.load(audio_file)
-    chunk_size_in_samples = chunk_size_in_sec * sr
-    chunk_starts = torch.arange(0, audio.shape[-1], chunk_size_in_samples)
+    chunk_starts = torch.arange(0, audio.shape[-1], chunk_size)
     for start in chunk_starts:
-        if start + chunk_size_in_samples > audio.shape[-1]:
+        if start + chunk_size > audio.shape[-1]:
             break
-        chunks.append(audio[:, start : start + chunk_size_in_samples])
+        chunks.append(audio[:, start : start + chunk_size])
     return chunks, sr
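The updated helper steps through the waveform in strides of chunk_size samples and drops any trailing partial chunk. A standalone sketch mirroring that logic on a synthetic signal (it re-implements the loop rather than importing the repo):

import torch

audio = torch.randn(1, 10 * 48000)        # 10 s of mono noise at 48 kHz
chunk_size = 262144                        # samples, as in cfg/config.yaml

chunks = []
for start in torch.arange(0, audio.shape[-1], chunk_size):
    if start + chunk_size > audio.shape[-1]:
        break                              # the final partial chunk is discarded
    chunks.append(audio[:, start : start + chunk_size])

print(len(chunks), chunks[0].shape)        # 1 chunk of shape (1, 262144)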
scripts/train.py
CHANGED
@@ -42,6 +42,7 @@ def main(cfg: DictConfig):
     summary = ModelSummary(model)
     print(summary)
     trainer.fit(model=model, datamodule=datamodule)
+    trainer.test(model=model, datamodule=datamodule)
 
 
 if __name__ == "__main__":
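As written, trainer.test evaluates the model's in-memory weights from the end of training. If the intent were to evaluate the best checkpoint written by the ModelCheckpoint callback instead, Lightning also accepts a checkpoint selector; shown here as an alternative, not as the repo's code:

trainer.test(model=model, datamodule=datamodule, ckpt_path="best")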