Merge pull request #24 from mhrice/new-metrics
Files changed:
- README.md +1 -1
- cfg/config.yaml +12 -7
- cfg/effects/all.yaml +29 -68
- cfg/effects/chorus.yaml +5 -18
- cfg/effects/compression.yaml +0 -22
- cfg/effects/compressor.yaml +9 -0
- cfg/effects/distortion.yaml +5 -12
- cfg/effects/reverb.yaml +11 -24
- cfg/exp/{demucs_compression.yaml → demucs_compressor.yaml} +1 -1
- cfg/exp/{umx_compression.yaml → umx_compressor.yaml} +1 -1
- remfx/datasets.py +11 -14
- remfx/models.py +28 -2
- remfx/utils.py +3 -4
- scripts/test.py +0 -1
- scripts/train.py +1 -0
README.md
CHANGED
@@ -22,7 +22,7 @@ Models and effects detailed below.
 
 To add gpu, add `trainer.accelerator='gpu' trainer.devices=-1` to the command-line
 
-Ex. `python scripts/train.py +exp=umx_distortion trainer.accelerator='gpu' trainer.devices
+Ex. `python scripts/train.py +exp=umx_distortion trainer.accelerator='gpu' trainer.devices=1`
 
 ### Current Models
 - `umx`
cfg/config.yaml
CHANGED
@@ -6,8 +6,8 @@ defaults:
 seed: 12345
 train: True
 sample_rate: 48000
+chunk_size: 262144 # 5.5s
 logs_dir: "./logs"
-log_every_n_steps: 1000
 render_files: True
 render_root: "./data/processed"
 
@@ -21,6 +21,9 @@ callbacks:
    verbose: False
    dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}
    filename: '{epoch:02d}-{valid_loss:.3f}'
+  learning_rate_monitor:
+    _target_: pytorch_lightning.callbacks.LearningRateMonitor
+    logging_interval: "step"
 
 datamodule:
   _target_: remfx.datasets.VocalSetDatamodule
@@ -28,27 +31,27 @@ datamodule:
    _target_: remfx.datasets.VocalSet
    sample_rate: ${sample_rate}
    root: ${oc.env:DATASET_ROOT}
-
+    chunk_size: ${chunk_size}
    mode: "train"
-    effect_types: ${effects
+    effect_types: ${effects}
    render_files: ${render_files}
    render_root: ${render_root}
  val_dataset:
    _target_: remfx.datasets.VocalSet
    sample_rate: ${sample_rate}
    root: ${oc.env:DATASET_ROOT}
-
+    chunk_size: ${chunk_size}
    mode: "val"
-    effect_types: ${effects
+    effect_types: ${effects}
    render_files: ${render_files}
    render_root: ${render_root}
  test_dataset:
    _target_: remfx.datasets.VocalSet
    sample_rate: ${sample_rate}
    root: ${oc.env:DATASET_ROOT}
-
+    chunk_size: ${chunk_size}
    mode: "test"
-    effect_types: ${effects
+    effect_types: ${effects}
    render_files: ${render_files}
    render_root: ${render_root}
 
@@ -76,3 +79,5 @@ trainer:
  accumulate_grad_batches: 1
  accelerator: null
  devices: 1
+  gradient_clip_val: 10.0
+  max_steps: 50000
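For reference, the new chunk_size of 262144 samples is about 262144 / 48000 ≈ 5.46 s at the configured sample_rate, and the dataset entries receive it through ${chunk_size} interpolation. A minimal sketch of how that interpolation resolves, using OmegaConf directly on an abbreviated, illustrative copy of the config (not the full Hydra compose):

    from omegaconf import OmegaConf

    # Illustrative subset of cfg/config.yaml: ${chunk_size} and ${sample_rate}
    # are resolved by OmegaConf when the dataset configs are read.
    cfg = OmegaConf.create(
        """
        sample_rate: 48000
        chunk_size: 262144  # 262144 / 48000 ~= 5.46 s
        datamodule:
          train_dataset:
            sample_rate: ${sample_rate}
            chunk_size: ${chunk_size}
        """
    )
    print(cfg.datamodule.train_dataset.chunk_size)  # 262144
    print(cfg.chunk_size / cfg.sample_rate)         # ~5.46 seconds per chunk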
cfg/effects/all.yaml
CHANGED
@@ -1,70 +1,31 @@
 # @package _global_
 effects:
-  ...
-      _target_: remfx.effects.RandomPedalboardChorus
-      sample_rate: ${sample_rate}
-      min_rate_hz: 1.0
-      max_rate_hz: 1.0
-      min_depth: 0.3
-      max_depth: 0.3
-      min_centre_delay_ms: 7.5
-      max_centre_delay_ms: 7.5
-      min_feedback: 0.4
-      max_feedback: 0.4
-      min_mix: 0.4
-      max_mix: 0.4
-    Distortion:
-      _target_: remfx.effects.RandomPedalboardDistortion
-      sample_rate: ${sample_rate}
-      min_drive_db: 30
-      max_drive_db: 30
-    Compressor:
-      _target_: remfx.effects.RandomPedalboardCompressor
-      sample_rate: ${sample_rate}
-      min_threshold_db: -32
-      max_threshold_db: -32
-      min_ratio: 3.0
-      max_ratio: 3.0
-      min_attack_ms: 10.0
-      max_attack_ms: 10.0
-      min_release_ms: 40.0
-      max_release_ms: 40.0
-    Reverb:
-      _target_: remfx.effects.RandomPedalboardReverb
-      sample_rate: ${sample_rate}
-      min_room_size: 0.5
-      max_room_size: 0.5
-      min_damping: 0.5
-      max_damping: 0.5
-      min_wet_dry: 0.4
-      max_wet_dry: 0.4
-      min_width: 0.5
-      max_width: 0.5
+  Chorus:
+    _target_: remfx.effects.RandomPedalboardChorus
+    sample_rate: ${sample_rate}
+    min_depth: 0.2
+    min_mix: 0.3
+  Distortion:
+    _target_: remfx.effects.RandomPedalboardDistortion
+    sample_rate: ${sample_rate}
+    min_drive_db: 10
+    max_drive_db: 50
+  Compressor:
+    _target_: remfx.effects.RandomPedalboardCompressor
+    sample_rate: ${sample_rate}
+    min_threshold_db: -42.0
+    max_threshold_db: -20.0
+    min_ratio: 1.5
+    max_ratio: 6.0
+  Reverb:
+    _target_: remfx.effects.RandomPedalboardReverb
+    sample_rate: ${sample_rate}
+    min_room_size: 0.3
+    max_room_size: 1.0
+    min_damping: 0.2
+    max_damping: 1.0
+    min_wet_dry: 0.2
+    max_wet_dry: 0.8
+    min_width: 0.2
+    max_width: 1.0
+
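Each entry under effects: is a Hydra-instantiable config (a _target_ plus keyword arguments); the dataset later samples one of these per chunk. A minimal sketch of the instantiation step for the Distortion entry, assuming the remfx package is importable so the _target_ resolves (the hard-coded sample_rate stands in for ${sample_rate}):

    import hydra.utils
    from omegaconf import OmegaConf

    # Turn one entry of cfg/effects/all.yaml into an effect object.
    distortion_cfg = OmegaConf.create(
        {
            "_target_": "remfx.effects.RandomPedalboardDistortion",
            "sample_rate": 48000,  # stands in for ${sample_rate}
            "min_drive_db": 10,
            "max_drive_db": 50,
        }
    )
    distortion = hydra.utils.instantiate(distortion_cfg)
    print(type(distortion))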
cfg/effects/chorus.yaml
CHANGED
@@ -1,20 +1,7 @@
 # @package _global_
 effects:
-  ...
-    Chorus:
-      _target_: remfx.effects.RandomPedalboardChorus
-      sample_rate: ${sample_rate}
-      min_rate_hz: 1.0
-      max_rate_hz: 1.0
-      min_depth: 0.3
-      max_depth: 0.3
-      min_centre_delay_ms: 7.5
-      max_centre_delay_ms: 7.5
-      min_feedback: 0.4
-      max_feedback: 0.4
-      min_mix: 0.4
-      max_mix: 0.4
+  Chorus:
+    _target_: remfx.effects.RandomPedalboardChorus
+    sample_rate: ${sample_rate}
+    min_depth: 0.2
+    min_mix: 0.3
cfg/effects/compression.yaml
DELETED
@@ -1,22 +0,0 @@
-# @package _global_
-effects:
-  train_effects:
-    Compressor:
-      _target_: remfx.effects.RandomPedalboardCompressor
-      sample_rate: ${sample_rate}
-      min_threshold_db: -42.0
-      max_threshold_db: -20.0
-      min_ratio: 1.5
-      max_ratio: 6.0
-  val_effects:
-    Compressor:
-      _target_: remfx.effects.RandomPedalboardCompressor
-      sample_rate: ${sample_rate}
-      min_threshold_db: -32
-      max_threshold_db: -32
-      min_ratio: 3.0
-      max_ratio: 3.0
-      min_attack_ms: 10.0
-      max_attack_ms: 10.0
-      min_release_ms: 40.0
-      max_release_ms: 40.0
cfg/effects/compressor.yaml
ADDED
@@ -0,0 +1,9 @@
+# @package _global_
+effects:
+  Compressor:
+    _target_: remfx.effects.RandomPedalboardCompressor
+    sample_rate: ${sample_rate}
+    min_threshold_db: -42.0
+    max_threshold_db: -20.0
+    min_ratio: 1.5
+    max_ratio: 6.0
cfg/effects/distortion.yaml
CHANGED
@@ -1,14 +1,7 @@
 # @package _global_
 effects:
-  ...
-      max_drive_db: 50
-  val_effects:
-    Distortion:
-      _target_: remfx.effects.RandomPedalboardDistortion
-      sample_rate: ${sample_rate}
-      min_drive_db: 30
-      max_drive_db: 30
+  Distortion:
+    _target_: remfx.effects.RandomPedalboardDistortion
+    sample_rate: ${sample_rate}
+    min_drive_db: 10
+    max_drive_db: 50
cfg/effects/reverb.yaml
CHANGED
@@ -1,26 +1,13 @@
 # @package _global_
 effects:
-  ...
-      max_width: 1.0
-  val_effects:
-    Reverb:
-      _target_: remfx.effects.RandomPedalboardReverb
-      sample_rate: ${sample_rate}
-      min_room_size: 0.5
-      max_room_size: 0.5
-      min_damping: 0.5
-      max_damping: 0.5
-      min_wet_dry: 0.4
-      max_wet_dry: 0.4
-      min_width: 0.5
-      max_width: 0.5
+  Reverb:
+    _target_: remfx.effects.RandomPedalboardReverb
+    sample_rate: ${sample_rate}
+    min_room_size: 0.3
+    max_room_size: 1.0
+    min_damping: 0.2
+    max_damping: 1.0
+    min_wet_dry: 0.2
+    max_wet_dry: 0.8
+    min_width: 0.2
+    max_width: 1.0
cfg/exp/{demucs_compression.yaml → demucs_compressor.yaml}
RENAMED
@@ -1,4 +1,4 @@
 # @package _global_
 defaults:
   - override /model: demucs
-  - override /effects:
+  - override /effects: compressor
cfg/exp/{umx_compression.yaml → umx_compressor.yaml}
RENAMED
@@ -1,4 +1,4 @@
 # @package _global_
 defaults:
   - override /model: umx
-  - override /effects:
+  - override /effects: compressor
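With the rename, the experiment configs select the single-entry compressor effects group, so the corresponding runs are launched with `python scripts/train.py +exp=umx_compressor` (or `+exp=demucs_compressor`). A sketch of composing the same config from Python, assuming Hydra >= 1.2 and that it is run from the repository root where cfg/ lives:

    from hydra import compose, initialize

    # Compose the renamed experiment config without launching a training run.
    with initialize(version_base=None, config_path="cfg"):
        cfg = compose(config_name="config", overrides=["+exp=umx_compressor"])

    # The effects group override pulls in cfg/effects/compressor.yaml.
    print(cfg.effects.Compressor.min_ratio)  # 1.5
    print(cfg.effects.Compressor.max_ratio)  # 6.0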
remfx/datasets.py
CHANGED
@@ -17,7 +17,7 @@ class VocalSet(Dataset):
         self,
         root: str,
         sample_rate: int,
-
+        chunk_size: int = 3,
         effect_types: List[torch.nn.Module] = None,
         render_files: bool = True,
         render_root: str = None,
@@ -28,7 +28,7 @@ class VocalSet(Dataset):
         self.song_idx = []
         self.root = Path(root)
         self.render_root = Path(render_root)
-        self.
+        self.chunk_size = chunk_size
         self.sample_rate = sample_rate
         self.mode = mode
 
@@ -36,9 +36,11 @@ class VocalSet(Dataset):
         self.files = sorted(list(mode_path.glob("./**/*.wav")))
         self.normalize = effects.LoudnessNormalize(sample_rate, target_lufs_db=-20)
         self.effect_types = effect_types
-
-        self.processed_root = self.render_root / "processed" / self.mode
-
+        effect_str = "_".join([e for e in self.effect_types])
+        self.processed_root = self.render_root / "processed" / effect_str / self.mode
+        if self.processed_root.exists():
+            print("Found processed files.")
+            render_files = False
         self.num_chunks = 0
         print("Total files:", len(self.files))
         print("Processing files...")
@@ -46,19 +48,14 @@ class VocalSet(Dataset):
             # Split audio file into chunks, resample, then apply random effects
             self.processed_root.mkdir(parents=True, exist_ok=True)
             for audio_file in tqdm(self.files, total=len(self.files)):
-                chunks, orig_sr = create_sequential_chunks(
-                    audio_file, self.chunk_size_in_sec
-                )
+                chunks, orig_sr = create_sequential_chunks(audio_file, self.chunk_size)
                 for chunk in chunks:
                     resampled_chunk = torchaudio.functional.resample(
                         chunk, orig_sr, sample_rate
                     )
-
-
-
-                        resampled_chunk,
-                        (0, chunk_size_in_samples - resampled_chunk.shape[1]),
-                    )
+                    if resampled_chunk.shape[-1] < chunk_size:
+                        # Skip if chunk is too small
+                        continue
                     # Apply effect
                     effect_idx = torch.rand(1).item() * len(self.effect_types.keys())
                     effect_name = list(self.effect_types.keys())[int(effect_idx)]
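Compared to the old padding logic, a chunk that comes out shorter than chunk_size samples after resampling is now simply skipped. A self-contained sketch of that check with illustrative values (the source rate and chunk length below are assumptions, not taken from the dataset):

    import torch
    import torchaudio.functional as F

    sample_rate = 48000   # target rate from the config
    chunk_size = 262144   # samples per training example
    orig_sr = 44100       # pretend source-file rate

    # A chunk read at the original rate; after resampling it is ~217k samples,
    # which is below chunk_size, so the loop in VocalSet.__init__ would skip it.
    chunk = torch.randn(1, 200_000)
    resampled = F.resample(chunk, orig_sr, sample_rate)
    if resampled.shape[-1] < chunk_size:
        print("skipped:", resampled.shape[-1], "<", chunk_size)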
remfx/models.py
CHANGED
@@ -55,6 +55,29 @@ class RemFXModel(pl.LightningModule):
         )
         return optimizer
 
+    # Add step-based learning rate scheduler
+    def optimizer_step(
+        self,
+        epoch,
+        batch_idx,
+        optimizer,
+        optimizer_idx,
+        optimizer_closure,
+        on_tpu,
+        using_native_amp,
+        using_lbfgs,
+    ):
+        # update params
+        optimizer.step(closure=optimizer_closure)
+
+        # update learning rate. Reduce by factor of 10 at 80% and 95% of training
+        if self.trainer.global_step == 0.8 * self.trainer.max_steps:
+            for pg in optimizer.param_groups:
+                pg["lr"] = 0.1 * pg["lr"]
+        if self.trainer.global_step == 0.95 * self.trainer.max_steps:
+            for pg in optimizer.param_groups:
+                pg["lr"] = 0.1 * pg["lr"]
+
     def training_step(self, batch, batch_idx):
         loss = self.common_step(batch, batch_idx, mode="train")
         return loss
@@ -215,7 +238,7 @@ class OpenUnmixModel(torch.nn.Module):
         X = spectrogram(x, self.window, self.n_fft, self.hop_length, self.alpha)
         Y = self.model(X)
         sep_out = self.separator(x).squeeze(1)
-        loss = self.mrstftloss(sep_out, target) + self.l1loss(sep_out, target)
+        loss = self.mrstftloss(sep_out, target) + self.l1loss(sep_out, target) * 100
 
         return loss, sep_out
 
@@ -236,7 +259,7 @@ class DemucsModel(torch.nn.Module):
     def forward(self, batch):
         x, target, label = batch
         output = self.model(x).squeeze(1)
-        loss = self.mrstftloss(output, target) + self.l1loss(output, target)
+        loss = self.mrstftloss(output, target) + self.l1loss(output, target) * 100
         return loss, output
 
     def sample(self, x: Tensor) -> Tensor:
@@ -264,10 +287,13 @@ def log_wandb_audio_batch(
     samples: Tensor,
     sampling_rate: int,
     caption: str = "",
+    max_items: int = 10,
 ):
     num_items = samples.shape[0]
     samples = rearrange(samples, "b c t -> b t c")
     for idx in range(num_items):
+        if idx >= max_items:
+            break
         logger.experiment.log(
             {
                 f"{id}_{idx}": wandb.Audio(
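The new optimizer_step hook keeps the base learning rate constant and cuts it by 10x when global_step reaches exactly 80% and then 95% of max_steps (steps 40000 and 47500 for the configured max_steps: 50000); because it mutates the optimizer's param groups in place, each drop persists for the rest of training. The same file also re-weights the loss so the L1 term counts 100x against the multi-resolution STFT term in both OpenUnmixModel and DemucsModel. A small sketch of the resulting effective learning-rate schedule (base_lr here is an illustrative value, not taken from this diff):

    def effective_lr(step: int, base_lr: float, max_steps: int) -> float:
        # Mirrors the hook's net effect: 10x drop at 80% and again at 95% of training.
        lr = base_lr
        if step >= 0.8 * max_steps:
            lr *= 0.1
        if step >= 0.95 * max_steps:
            lr *= 0.1
        return lr

    for step in (0, 39_999, 40_000, 47_500, 50_000):
        print(step, effective_lr(step, base_lr=1e-4, max_steps=50_000))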
remfx/utils.py
CHANGED
@@ -132,10 +132,9 @@ def create_sequential_chunks(
     """
     chunks = []
     audio, sr = torchaudio.load(audio_file)
-
-    chunk_starts = torch.arange(0, audio.shape[-1], chunk_size_in_samples)
+    chunk_starts = torch.arange(0, audio.shape[-1], chunk_size)
     for start in chunk_starts:
-        if start +
+        if start + chunk_size > audio.shape[-1]:
             break
-        chunks.append(audio[:, start : start +
+        chunks.append(audio[:, start : start + chunk_size])
     return chunks, sr
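create_sequential_chunks now steps through the file in fixed windows of chunk_size samples and stops at the first window that would run past the end of the signal, so trailing audio shorter than one chunk is discarded rather than padded. A self-contained sketch of that boundary behaviour on a synthetic tensor (illustrative values; this does not call the repo function):

    import torch

    chunk_size = 262144
    audio = torch.randn(1, 1_000_000)  # ~20.8 s of audio at 48 kHz

    chunks = []
    for start in torch.arange(0, audio.shape[-1], chunk_size):
        if start + chunk_size > audio.shape[-1]:
            break  # drop the trailing partial window
        chunks.append(audio[:, start : start + chunk_size])

    print(len(chunks))  # 3 full chunks; the ~213k-sample remainder is discarded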
scripts/test.py
CHANGED
@@ -14,7 +14,6 @@ def main(cfg: DictConfig):
     # Apply seed for reproducibility
     if cfg.seed:
         pl.seed_everything(cfg.seed)
-    cfg.render_files = False
     log.info(f"Instantiating datamodule <{cfg.datamodule._target_}>.")
     datamodule = hydra.utils.instantiate(cfg.datamodule, _convert_="partial")
     log.info(f"Instantiating model <{cfg.model._target_}>.")
scripts/train.py
CHANGED
@@ -42,6 +42,7 @@ def main(cfg: DictConfig):
     summary = ModelSummary(model)
     print(summary)
     trainer.fit(model=model, datamodule=datamodule)
+    trainer.test(model=model, datamodule=datamodule, ckpt_path="best")
 
 
 if __name__ == "__main__":
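trainer.test(..., ckpt_path="best") reloads the best checkpoint tracked by the ModelCheckpoint callback configured in cfg/config.yaml (which monitors valid_loss) before running the test loop. A toy, self-contained sketch of that fit-then-test flow, assuming a recent pytorch_lightning (>= 1.6); the model and data are placeholders, not the repo's:

    import torch
    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import ModelCheckpoint
    from torch.utils.data import DataLoader, TensorDataset

    class ToyModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(8, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return torch.nn.functional.mse_loss(self.layer(x), y)

        def validation_step(self, batch, batch_idx):
            x, y = batch
            # "valid_loss" matches the metric the checkpoint callback monitors
            self.log("valid_loss", torch.nn.functional.mse_loss(self.layer(x), y))

        def test_step(self, batch, batch_idx):
            x, y = batch
            self.log("test_loss", torch.nn.functional.mse_loss(self.layer(x), y))

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)

    def loader():
        return DataLoader(TensorDataset(torch.randn(64, 8), torch.randn(64, 1)), batch_size=16)

    model = ToyModel()
    trainer = pl.Trainer(
        max_epochs=2,
        callbacks=[ModelCheckpoint(monitor="valid_loss")],
        logger=False,
        enable_progress_bar=False,
    )
    trainer.fit(model, train_dataloaders=loader(), val_dataloaders=loader())
    trainer.test(model, dataloaders=loader(), ckpt_path="best")  # reloads the best valid_loss checkpoint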