Commit 8125531
Parent(s): 7ed6389

Clean project. Add 'all effects' to experiments
Files changed:

- README.md +8 -2
- cfg/config.yaml +12 -2
- cfg/config_guitarset.yaml +0 -52
- cfg/config_guitfx.yaml +0 -52
- cfg/effects/all.yaml +70 -0
- cfg/exp/demucs_all.yaml +4 -0
- cfg/exp/umx_all.yaml +4 -0
- remfx/datasets.py +27 -259
- remfx/models.py +44 -72
- remfx/utils.py +71 -1
- shell_vars.sh +0 -1
README.md CHANGED

```diff
@@ -13,7 +13,7 @@
 4. Manually split singers into train, val, test directories
 
 ## Train model
-1. Change Wandb variables in `shell_vars.sh` and `source shell_vars.sh`
+1. Change Wandb and data root variables in `shell_vars.sh` and `source shell_vars.sh`
 2. `python scripts/train.py +exp=umx_distortion`
 or
 2. `python scripts/train.py +exp=demucs_distortion`
@@ -33,6 +33,12 @@ Ex. `python scripts/train.py +exp=umx_distortion trainer.accelerator='gpu' train
 - `compressor`
 - `distortion`
 - `reverb`
+- `all` (choose random effect to apply to each file)
 
 ## Misc.
-
+By default, files are rendered to `input_dir / processed / train/val/test`.
+To skip rendering files (use previously rendered), add `render_files=False` to the command-line
+
+Test
+Experiment dictates data, ckpt dictates model
+`python scripts/test.py +exp=umx_distortion.yaml +ckpt_path=test_ckpts/umx_dist.ckpt`
```
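As a quick illustration of the rendering layout described above, here is a minimal sketch (not code from the repo; it assumes the per-chunk `input.wav` / `target.wav` / `effect_name.pt` layout introduced in `remfx/datasets.py` below) that checks whether a split has already been rendered before passing `render_files=False`:

```python
from pathlib import Path

def rendered_chunk_count(dataset_root: str, split: str) -> int:
    """Count rendered chunk folders under <dataset_root>/processed/<split>."""
    split_dir = Path(dataset_root) / "processed" / split
    if not split_dir.exists():
        return 0
    # Each rendered chunk directory holds input.wav, target.wav, effect_name.pt
    return sum(1 for d in split_dir.iterdir() if (d / "input.wav").exists())

# Example: only reuse rendered chunks if the train split already exists.
if rendered_chunk_count("./data/VocalSet", "train") > 0:
    print("Safe to pass render_files=False")
```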
cfg/config.yaml CHANGED

```diff
@@ -8,6 +8,7 @@ train: True
 sample_rate: 48000
 logs_dir: "./logs"
 log_every_n_steps: 1000
+render_files: True
 
 callbacks:
   model_checkpoint:
@@ -26,18 +27,27 @@ datamodule:
     _target_: remfx.datasets.VocalSet
     sample_rate: ${sample_rate}
     root: ${oc.env:DATASET_ROOT}
-    output_root: ${oc.env:OUTPUT_ROOT}/train
     chunk_size_in_sec: 6
    mode: "train"
     effect_types: ${effects.train_effects}
+    render_files: ${render_files}
  val_dataset:
     _target_: remfx.datasets.VocalSet
     sample_rate: ${sample_rate}
     root: ${oc.env:DATASET_ROOT}
-    output_root: ${oc.env:OUTPUT_ROOT}/val
     chunk_size_in_sec: 6
     mode: "val"
     effect_types: ${effects.val_effects}
+    render_files: ${render_files}
+  test_dataset:
+    _target_: remfx.datasets.VocalSet
+    sample_rate: ${sample_rate}
+    root: ${oc.env:DATASET_ROOT}
+    chunk_size_in_sec: 6
+    mode: "test"
+    effect_types: ${effects.val_effects}
+    render_files: ${render_files}
+
   batch_size: 16
   num_workers: 8
   pin_memory: True
```
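For context, a minimal sketch (illustrative only; the repo's real entry point is `scripts/train.py`) of how a config like this is typically consumed: the top-level `render_files` flag reaches the train, val and test datasets through the `${render_files}` interpolation, so one command-line override controls all three.

```python
import hydra
from hydra.utils import instantiate
from omegaconf import OmegaConf

@hydra.main(config_path="cfg", config_name="config", version_base=None)
def main(cfg):
    # e.g. appending `render_files=False` on the command line skips
    # re-rendering for every split at once.
    print(OmegaConf.to_yaml(cfg.datamodule.test_dataset))
    # Instantiating the datamodule builds the datasets; with render_files=True
    # this is where chunks get rendered to disk.
    datamodule = instantiate(cfg.datamodule)
    print(type(datamodule).__name__)

if __name__ == "__main__":
    main()
```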
cfg/config_guitarset.yaml DELETED

The entire 52-line file is removed:

```yaml
defaults:
  - _self_
  - exp: null
seed: 12345
train: True
sample_rate: 48000
logs_dir: "./logs"
log_every_n_steps: 1000

callbacks:
  model_checkpoint:
    _target_: pytorch_lightning.callbacks.ModelCheckpoint
    monitor: "valid_loss" # name of the logged metric which determines when model is improving
    save_top_k: 1 # save k best models (determined by above metric)
    save_last: True # additionaly always save model from last epoch
    mode: "min" # can be "max" or "min"
    verbose: False
    dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}
    filename: '{epoch:02d}-{valid_loss:.3f}'

datamodule:
  _target_: remfx.datasets.Datamodule
  dataset:
    _target_: remfx.datasets.GuitarSet
    sample_rate: ${sample_rate}
    root: ${oc.env:DATASET_ROOT}
    chunk_size_in_sec: 6
  val_split: 0.2
  batch_size: 16
  num_workers: 8
  pin_memory: True
  persistent_workers: True

logger:
  _target_: pytorch_lightning.loggers.WandbLogger
  project: ${oc.env:WANDB_PROJECT}
  entity: ${oc.env:WANDB_ENTITY}
  # offline: False # set True to store all logs only locally
  job_type: "train"
  group: ""
  save_dir: "."

trainer:
  _target_: pytorch_lightning.Trainer
  precision: 32 # Precision used for tensors, default `32`
  min_epochs: 0
  max_epochs: -1
  enable_model_summary: False
  log_every_n_steps: 1 # Logs metrics every N batches
  accumulate_grad_batches: 1
  accelerator: null
  devices: 1
```
cfg/config_guitfx.yaml DELETED

The deleted file (52 lines) is identical to cfg/config_guitarset.yaml above except that the dataset target is `_target_: remfx.datasets.GuitarFXDataset` instead of `remfx.datasets.GuitarSet`.
cfg/effects/all.yaml ADDED

New 70-line file defining the `all` effects group:

```yaml
# @package _global_
effects:
  train_effects:
    Chorus:
      _target_: remfx.effects.RandomPedalboardChorus
      sample_rate: ${sample_rate}
    Distortion:
      _target_: remfx.effects.RandomPedalboardDistortion
      sample_rate: ${sample_rate}
      min_drive_db: -10
      max_drive_db: 50
    Compressor:
      _target_: remfx.effects.RandomPedalboardCompressor
      sample_rate: ${sample_rate}
      min_threshold_db: -42.0
      max_threshold_db: -20.0
      min_ratio: 1.5
      max_ratio: 6.0
    Reverb:
      _target_: remfx.effects.RandomPedalboardReverb
      sample_rate: ${sample_rate}
      min_room_size: 0.3
      max_room_size: 1.0
      min_damping: 0.2
      max_damping: 1.0
      min_wet_dry: 0.2
      max_wet_dry: 0.8
      min_width: 0.2
      max_width: 1.0
  val_effects:
    Chorus:
      _target_: remfx.effects.RandomPedalboardChorus
      sample_rate: ${sample_rate}
      min_rate_hz: 1.0
      max_rate_hz: 1.0
      min_depth: 0.3
      max_depth: 0.3
      min_centre_delay_ms: 7.5
      max_centre_delay_ms: 7.5
      min_feedback: 0.4
      max_feedback: 0.4
      min_mix: 0.4
      max_mix: 0.4
    Distortion:
      _target_: remfx.effects.RandomPedalboardDistortion
      sample_rate: ${sample_rate}
      min_drive_db: 30
      max_drive_db: 30
    Compressor:
      _target_: remfx.effects.RandomPedalboardCompressor
      sample_rate: ${sample_rate}
      min_threshold_db: -32
      max_threshold_db: -32
      min_ratio: 3.0
      max_ratio: 3.0
      min_attack_ms: 10.0
      max_attack_ms: 10.0
      min_release_ms: 40.0
      max_release_ms: 40.0
    Reverb:
      _target_: remfx.effects.RandomPedalboardReverb
      sample_rate: ${sample_rate}
      min_room_size: 0.5
      max_room_size: 0.5
      min_damping: 0.5
      max_damping: 0.5
      min_wet_dry: 0.4
      max_wet_dry: 0.4
      min_width: 0.5
      max_width: 0.5
```
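A small sketch of how one entry from this file is turned into an effect and applied to audio (assumption: the `remfx.effects.RandomPedalboard*` classes are callables on a waveform tensor, as the dataset code below suggests):

```python
import torch
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "_target_": "remfx.effects.RandomPedalboardDistortion",
    "sample_rate": 48000,
    "min_drive_db": -10,
    "max_drive_db": 50,
})
distortion = instantiate(cfg)       # builds the effect module
chunk = torch.randn(1, 48000 * 6)   # one 6-second mono chunk
wet = distortion(chunk)             # drive drawn from the configured range
```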
cfg/exp/demucs_all.yaml ADDED

```yaml
# @package _global_
defaults:
  - override /model: demucs
  - override /effects: all
```

cfg/exp/umx_all.yaml ADDED

```yaml
# @package _global_
defaults:
  - override /model: umx
  - override /effects: all
```
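These experiment files follow the existing pattern, so the new runs are launched with `python scripts/train.py +exp=demucs_all` or `+exp=umx_all`. A hedged sketch of inspecting the composed config with Hydra's compose API (assuming the snippet sits next to the repo's `cfg/` directory):

```python
from hydra import compose, initialize

with initialize(version_base=None, config_path="cfg"):
    cfg = compose(config_name="config", overrides=["+exp=umx_all"])

# The `all` effects group is merged at the top level via `# @package _global_`.
print(list(cfg.effects.train_effects.keys()))  # e.g. ['Chorus', 'Distortion', ...]
```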
remfx/datasets.py CHANGED

The pedalboard-based `GuitarFXDataset` and `GuitarSet` datasets, the `deterministic_effects` table, and the generic `Datamodule` are removed. `VocalSet` now renders chunks to `root / "processed" / <mode> / <chunk_idx>/` and reads them back in `__getitem__`, and `VocalSetDatamodule` gains a test dataset and dataloader. The helpers `create_random_chunks` and `create_sequential_chunks` are removed here and re-added unchanged in `remfx/utils.py` (see that file's diff below).

```diff
@@ -1,179 +1,16 @@
 import torch
-from torch.utils.data import Dataset, DataLoader
+from torch.utils.data import Dataset, DataLoader
 import torchaudio
-import torchaudio.transforms as T
 import torch.nn.functional as F
 from pathlib import Path
 import pytorch_lightning as pl
-from typing import Any, List
+from typing import Any, List
 from remfx import effects
-from pedalboard import (
-    Pedalboard,
-    Chorus,
-    Reverb,
-    Compressor,
-    Phaser,
-    Delay,
-    Distortion,
-    Limiter,
-)
 from tqdm import tqdm
+from remfx.utils import create_sequential_chunks
 
-# https://zenodo.org/record/7044411/ -> GuitarFX
-# https://zenodo.org/record/3371780 -> GuitarSet
 # https://zenodo.org/record/1193957 -> VocalSet
 
-deterministic_effects = {
-    "Distortion": Pedalboard([Distortion()]),
-    "Compressor": Pedalboard([Compressor()]),
-    "Chorus": Pedalboard([Chorus()]),
-    "Phaser": Pedalboard([Phaser()]),
-    "Delay": Pedalboard([Delay()]),
-    "Reverb": Pedalboard([Reverb()]),
-    "Limiter": Pedalboard([Limiter()]),
-}
-
-
-class GuitarFXDataset(Dataset):
-    def __init__(
-        self,
-        root: str,
-        sample_rate: int,
-        chunk_size_in_sec: int = 3,
-        effect_types: List[str] = None,
-    ):
-        super().__init__()
-        self.wet_files = []
-        self.dry_files = []
-        self.chunks = []
-        self.labels = []
-        self.song_idx = []
-        self.root = Path(root)
-        self.chunk_size_in_sec = chunk_size_in_sec
-        self.sample_rate = sample_rate
-
-        if effect_types is None:
-            effect_types = [
-                d.name for d in self.root.iterdir() if d.is_dir() and d != "Clean"
-            ]
-        current_file = 0
-        for i, effect in enumerate(effect_types):
-            for pickup in Path(self.root / effect).iterdir():
-                wet_files = sorted(list(pickup.glob("*.wav")))
-                dry_files = sorted(
-                    list(self.root.glob(f"Clean/{pickup.name}/**/*.wav"))
-                )
-                self.wet_files += wet_files
-                self.dry_files += dry_files
-                self.labels += [i] * len(wet_files)
-                for audio_file in wet_files:
-                    chunk_starts, orig_sr = create_sequential_chunks(
-                        audio_file, self.chunk_size_in_sec
-                    )
-                    self.chunks += chunk_starts
-                    self.song_idx += [current_file] * len(chunk_starts)
-                    current_file += 1
-        print(
-            f"Found {len(self.wet_files)} wet files and {len(self.dry_files)} dry files.\n"
-            f"Total chunks: {len(self.chunks)}"
-        )
-        self.resampler = T.Resample(orig_sr, sample_rate)
-
-    def __len__(self):
-        return len(self.chunks)
-
-    def __getitem__(self, idx):
-        # Load effected and "clean" audio
-        song_idx = self.song_idx[idx]
-        x, sr = torchaudio.load(self.wet_files[song_idx])
-        y, sr = torchaudio.load(self.dry_files[song_idx])
-        effect_label = self.labels[song_idx]  # Effect label
-
-        chunk_start = self.chunks[idx]
-        chunk_size_in_samples = self.chunk_size_in_sec * sr
-        x = x[:, chunk_start : chunk_start + chunk_size_in_samples]
-        y = y[:, chunk_start : chunk_start + chunk_size_in_samples]
-
-        resampled_x = self.resampler(x)
-        resampled_y = self.resampler(y)
-        # Reset chunk size to be new sample rate
-        chunk_size_in_samples = self.chunk_size_in_sec * self.sample_rate
-        # Pad to chunk_size if needed
-        if resampled_x.shape[-1] < chunk_size_in_samples:
-            resampled_x = F.pad(
-                resampled_x, (0, chunk_size_in_samples - resampled_x.shape[1])
-            )
-        if resampled_y.shape[-1] < chunk_size_in_samples:
-            resampled_y = F.pad(
-                resampled_y, (0, chunk_size_in_samples - resampled_y.shape[1])
-            )
-        return (resampled_x, resampled_y, effect_label)
-
-
-class GuitarSet(Dataset):
-    def __init__(
-        self,
-        root: str,
-        sample_rate: int,
-        chunk_size_in_sec: int = 3,
-        effect_types: List[torch.nn.Module] = None,
-    ):
-        super().__init__()
-        self.chunks = []
-        self.song_idx = []
-        self.root = Path(root)
-        self.chunk_size_in_sec = chunk_size_in_sec
-        self.files = sorted(list(self.root.glob("./**/*.wav")))
-        self.sample_rate = sample_rate
-        for i, audio_file in enumerate(self.files):
-            chunk_starts, orig_sr = create_sequential_chunks(
-                audio_file, self.chunk_size_in_sec
-            )
-            self.chunks += chunk_starts
-            self.song_idx += [i] * len(chunk_starts)
-        print(f"Found {len(self.files)} files .\n" f"Total chunks: {len(self.chunks)}")
-        self.resampler = T.Resample(orig_sr, sample_rate)
-        self.effect_types = effect_types
-        self.normalize = effects.LoudnessNormalize(sample_rate, target_lufs_db=-20)
-        self.mode = "train"
-
-    def __len__(self):
-        return len(self.chunks)
-
-    def __getitem__(self, idx):
-        # Load and effect audio
-        song_idx = self.song_idx[idx]
-        x, sr = torchaudio.load(self.files[song_idx])
-        chunk_start = self.chunks[idx]
-        chunk_size_in_samples = self.chunk_size_in_sec * sr
-        x = x[:, chunk_start : chunk_start + chunk_size_in_samples]
-        resampled_x = self.resampler(x)
-        # Reset chunk size to be new sample rate
-        chunk_size_in_samples = self.chunk_size_in_sec * self.sample_rate
-        # Pad to chunk_size if needed
-        if resampled_x.shape[-1] < chunk_size_in_samples:
-            resampled_x = F.pad(
-                resampled_x, (0, chunk_size_in_samples - resampled_x.shape[1])
-            )
-
-        # Add random effect if train
-        if self.mode == "train":
-            random_effect_idx = torch.rand(1).item() * len(self.effect_types.keys())
-            effect_name = list(self.effect_types.keys())[int(random_effect_idx)]
-            effect = self.effect_types[effect_name]
-            effected_input = effect(resampled_x)
-        else:
-            # deterministic static effect for eval
-            effect_idx = idx % len(self.effect_types.keys())
-            effect_name = list(self.effect_types.keys())[effect_idx]
-            effect = deterministic_effects[effect_name]
-            effected_input = torch.from_numpy(
-                effect(resampled_x.numpy(), self.sample_rate)
-            )
-        normalized_input = self.normalize(effected_input)
-        normalized_target = self.normalize(resampled_x)
-        return (normalized_input, normalized_target, effect_name)
-
 
 class VocalSet(Dataset):
     def __init__(
@@ -183,7 +20,6 @@ class VocalSet(Dataset):
         chunk_size_in_sec: int = 3,
         effect_types: List[torch.nn.Module] = None,
         render_files: bool = True,
-        output_root: str = "processed",
         mode: str = "train",
     ):
         super().__init__()
@@ -199,14 +35,15 @@ class VocalSet(Dataset):
         self.normalize = effects.LoudnessNormalize(sample_rate, target_lufs_db=-20)
         self.effect_types = effect_types
 
-        self.
+        self.processed_root = self.root / "processed" / self.mode
 
         self.num_chunks = 0
         print("Total files:", len(self.files))
         print("Processing files...")
         if render_files:
-
-
+            # Split audio file into chunks, resample, then apply random effects
+            self.processed_root.mkdir(parents=True, exist_ok=True)
+            for audio_file in tqdm(self.files, total=len(self.files)):
                 chunks, orig_sr = create_sequential_chunks(
                     audio_file, self.chunk_size_in_sec
                 )
@@ -220,14 +57,16 @@ class VocalSet(Dataset):
                            resampled_chunk,
                            (0, chunk_size_in_samples - resampled_chunk.shape[1]),
                        )
+                    # Apply effect
                    effect_idx = torch.rand(1).item() * len(self.effect_types.keys())
                    effect_name = list(self.effect_types.keys())[int(effect_idx)]
                    effect = self.effect_types[effect_name]
                    effected_input = effect(resampled_chunk)
+                    # Normalize
                    normalized_input = self.normalize(effected_input)
                    normalized_target = self.normalize(resampled_chunk)
 
-                    output_dir = self.
+                    output_dir = self.processed_root / str(self.num_chunks)
                    output_dir.mkdir(exist_ok=True)
                    torchaudio.save(
                        output_dir / "input.wav", normalized_input, self.sample_rate
@@ -235,9 +74,10 @@ class VocalSet(Dataset):
                    torchaudio.save(
                        output_dir / "target.wav", normalized_target, self.sample_rate
                    )
+                    torch.save(effect_name, output_dir / "effect_name.pt")
                    self.num_chunks += 1
         else:
-            self.num_chunks = len(list(self.
+            self.num_chunks = len(list(self.processed_root.iterdir()))
 
         print(
             f"Found {len(self.files)} {self.mode} files .\n"
@@ -248,95 +88,12 @@ class VocalSet(Dataset):
         return self.num_chunks
 
     def __getitem__(self, idx):
-
-
-
+        input_file = self.processed_root / str(idx) / "input.wav"
+        target_file = self.processed_root / str(idx) / "target.wav"
+        effect_name = torch.load(self.processed_root / str(idx) / "effect_name.pt")
         input, sr = torchaudio.load(input_file)
         target, sr = torchaudio.load(target_file)
-        return (input, target,
+        return (input, target, effect_name)
```

(The removed lines that follow in this hunk are `create_random_chunks` and `create_sequential_chunks`, moved unchanged into `remfx/utils.py`, plus the old `Datamodule` class shown below.)

```diff
-class Datamodule(pl.LightningDataModule):
-    def __init__(
-        self,
-        dataset,
-        *,
-        val_split: float,
-        batch_size: int,
-        num_workers: int,
-        pin_memory: bool = False,
-        **kwargs: int,
-    ) -> None:
-        super().__init__()
-        self.dataset = dataset
-        self.val_split = val_split
-        self.batch_size = batch_size
-        self.num_workers = num_workers
-        self.pin_memory = pin_memory
-        self.data_train: Any = None
-        self.data_val: Any = None
-
-    def setup(self, stage: Any = None) -> None:
-        split = [1.0 - self.val_split, self.val_split]
-        train_size = round(split[0] * len(self.dataset))
-        val_size = round(split[1] * len(self.dataset))
-        self.data_train, self.data_val = random_split(
-            self.dataset, [train_size, val_size]
-        )
-        self.data_val.dataset.mode = "val"
-
-    def train_dataloader(self) -> DataLoader:
-        return DataLoader(
-            dataset=self.data_train,
-            batch_size=self.batch_size,
-            num_workers=self.num_workers,
-            pin_memory=self.pin_memory,
-            shuffle=True,
-        )
-
-    def val_dataloader(self) -> DataLoader:
-        return DataLoader(
-            dataset=self.data_val,
-            batch_size=self.batch_size,
-            num_workers=self.num_workers,
-            pin_memory=self.pin_memory,
-            shuffle=False,
-        )
 
 
 class VocalSetDatamodule(pl.LightningDataModule):
@@ -344,6 +101,7 @@ class VocalSetDatamodule(pl.LightningDataModule):
         self,
         train_dataset,
         val_dataset,
+        test_dataset,
         *,
         batch_size: int,
         num_workers: int,
@@ -353,6 +111,7 @@ class VocalSetDatamodule(pl.LightningDataModule):
         super().__init__()
         self.train_dataset = train_dataset
         self.val_dataset = val_dataset
+        self.test_dataset = test_dataset
         self.batch_size = batch_size
         self.num_workers = num_workers
         self.pin_memory = pin_memory
@@ -377,3 +136,12 @@ class VocalSetDatamodule(pl.LightningDataModule):
             pin_memory=self.pin_memory,
             shuffle=False,
         )
+
+    def test_dataloader(self) -> DataLoader:
+        return DataLoader(
+            dataset=self.test_dataset,
+            batch_size=self.batch_size,
+            num_workers=self.num_workers,
+            pin_memory=self.pin_memory,
+            shuffle=False,
+        )
```
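To make the new on-disk format concrete, a minimal sketch (not repository code) that loads one pre-rendered chunk the same way `VocalSet.__getitem__` now does, assuming the default `DATASET_ROOT` and the test split:

```python
import torch
import torchaudio
from pathlib import Path

processed_root = Path("./data/VocalSet") / "processed" / "test"
chunk_dir = processed_root / "0"                           # chunk index 0

input_wav, sr = torchaudio.load(chunk_dir / "input.wav")   # effected audio
target_wav, _ = torchaudio.load(chunk_dir / "target.wav")  # clean audio
effect_name = torch.load(chunk_dir / "effect_name.pt")     # e.g. "Reverb"

print(input_wav.shape, target_wav.shape, sr, effect_name)
```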
remfx/models.py CHANGED

`FADLoss` moves out to `remfx/utils.py`, a `test_step` and an `on_test_batch_start` hook are added, and `on_validation_batch_start` is rewritten to log input metrics before sampling.

```diff
@@ -7,44 +7,12 @@ from audio_diffusion_pytorch import DiffusionModel
 from auraloss.time import SISDRLoss
 from auraloss.freq import MultiResolutionSTFTLoss, STFTLoss
 from torch.nn import L1Loss
-from
-import numpy as np
+from remfx.utils import FADLoss
 
 from umx.openunmix.model import OpenUnmix, Separator
 from torchaudio.models import HDemucs
 
 
 class RemFXModel(pl.LightningModule):
     def __init__(
         self,
```

(The remainder of this hunk removes the `FADLoss` class, which is re-added unchanged in `remfx/utils.py` below.)

```diff
@@ -97,6 +65,10 @@ class RemFXModel(pl.LightningModule):
         loss = self.common_step(batch, batch_idx, mode="valid")
         return loss
 
+    def test_step(self, batch, batch_idx):
+        loss = self.common_step(batch, batch_idx, mode="test")
+        return loss
+
     def common_step(self, batch, batch_idx, mode: str = "train"):
         loss, output = self.model(batch)
         self.log(f"{mode}_loss", loss)
@@ -121,6 +93,7 @@ class RemFXModel(pl.LightningModule):
         return loss
 
     def on_train_batch_start(self, batch, batch_idx):
+        # Log initial audio
         if self.log_train_audio:
             x, y, label = batch
             # Concat samples together for easier viewing in dashboard
@@ -143,48 +116,47 @@ class RemFXModel(pl.LightningModule):
         )
         self.log_train_audio = False
 
-    def on_validation_epoch_start(self):
-        self.log_next = True
-
     def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
+        x, target, label = batch
+        # Log Input Metrics
+        for metric in self.metrics:
+            # SISDR returns negative values, so negate them
+            if metric == "SISDR":
+                negate = -1
+            else:
+                negate = 1
+            self.log(
+                f"Input_{metric}",
+                negate * self.metrics[metric](x, target),
+                on_step=False,
+                on_epoch=True,
+                logger=True,
+                prog_bar=True,
+                sync_dist=True,
+            )
 
+        self.model.eval()
+        with torch.no_grad():
+            y = self.model.sample(x)
+
+        # Concat samples together for easier viewing in dashboard
+        # 2 seconds of silence between each sample
+        silence = torch.zeros_like(x)
+        silence = silence[:, : self.sample_rate * 2]
+
+        concat_samples = torch.cat([y, silence, x, silence, target], dim=-1)
+        log_wandb_audio_batch(
+            logger=self.logger,
+            id="prediction_input_target",
+            samples=concat_samples.cpu(),
+            sampling_rate=self.sample_rate,
+            caption=f"Epoch {self.current_epoch}",
+        )
+        self.log_next = False
+        self.model.train()
 
+    def on_test_batch_start(self, batch, batch_idx, dataloader_idx):
+        return self.on_validation_batch_start(batch, batch_idx, dataloader_idx)
 
 
 class OpenUnmixModel(torch.nn.Module):
```

(The old body of `on_validation_batch_start`, previously gated by the `self.log_next` flag set in the removed `on_validation_epoch_start`, is replaced wholesale by the added lines above.)
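With `test_step` in the model and `test_dataloader` in the datamodule, evaluation can go through Lightning's standard test loop. A minimal illustrative sketch (this is not the repo's `scripts/test.py`, whose contents are not shown in this commit):

```python
import pytorch_lightning as pl

def run_test(model: pl.LightningModule, datamodule: pl.LightningDataModule):
    """Assumes `model` is a RemFXModel restored from a checkpoint and
    `datamodule` is the VocalSetDatamodule built by Hydra."""
    trainer = pl.Trainer(accelerator="auto", devices=1, logger=False)
    # Runs RemFXModel.test_step / on_test_batch_start over the test_dataloader.
    return trainer.test(model=model, datamodule=datamodule)
```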
remfx/utils.py CHANGED

```diff
@@ -1,8 +1,12 @@
 import logging
-from typing import List
+from typing import List, Tuple
 import pytorch_lightning as pl
 from omegaconf import DictConfig
 from pytorch_lightning.utilities import rank_zero_only
+from frechet_audio_distance import FrechetAudioDistance
+import numpy as np
+import torch
+import torchaudio
 
 
 def get_logger(name=__name__) -> logging.Logger:
@@ -69,3 +73,69 @@ def log_hyperparameters(
     hparams["callbacks"] = config["callbacks"]
 
     logger.experiment.config.update(hparams)
+
+
+class FADLoss(torch.nn.Module):
+    def __init__(self, sample_rate: float):
+        super().__init__()
+        self.fad = FrechetAudioDistance(
+            use_pca=False, use_activation=False, verbose=False
+        )
+        self.fad.model = self.fad.model.to("cpu")
+        self.sr = sample_rate
+
+    def forward(self, audio_background, audio_eval):
+        embds_background = []
+        embds_eval = []
+        for sample in audio_background:
+            embd = self.fad.model.forward(sample.T.cpu().detach().numpy(), self.sr)
+            embds_background.append(embd.cpu().detach().numpy())
+        for sample in audio_eval:
+            embd = self.fad.model.forward(sample.T.cpu().detach().numpy(), self.sr)
+            embds_eval.append(embd.cpu().detach().numpy())
+        embds_background = np.concatenate(embds_background, axis=0)
+        embds_eval = np.concatenate(embds_eval, axis=0)
+        mu_background, sigma_background = self.fad.calculate_embd_statistics(
+            embds_background
+        )
+        mu_eval, sigma_eval = self.fad.calculate_embd_statistics(embds_eval)
+
+        fad_score = self.fad.calculate_frechet_distance(
+            mu_background, sigma_background, mu_eval, sigma_eval
+        )
+        return fad_score
+
+
+def create_random_chunks(
+    audio_file: str, chunk_size: int, num_chunks: int
+) -> Tuple[List[Tuple[int, int]], int]:
+    """Create num_chunks random chunks of size chunk_size (seconds)
+    from an audio file.
+    Return sample_index of start of each chunk and original sr
+    """
+    audio, sr = torchaudio.load(audio_file)
+    chunk_size_in_samples = chunk_size * sr
+    if chunk_size_in_samples >= audio.shape[-1]:
+        chunk_size_in_samples = audio.shape[-1] - 1
+    chunks = []
+    for i in range(num_chunks):
+        start = torch.randint(0, audio.shape[-1] - chunk_size_in_samples, (1,)).item()
+        chunks.append(start)
+    return chunks, sr
+
+
+def create_sequential_chunks(
+    audio_file: str, chunk_size: int
+) -> Tuple[List[Tuple[int, int]], int]:
+    """Create sequential chunks of size chunk_size (seconds) from an audio file.
+    Return sample_index of start of each chunk and original sr
+    """
+    chunks = []
+    audio, sr = torchaudio.load(audio_file)
+    chunk_size_in_samples = chunk_size * sr
+    chunk_starts = torch.arange(0, audio.shape[-1], chunk_size_in_samples)
+    for start in chunk_starts:
+        if start + chunk_size_in_samples > audio.shape[-1]:
+            break
+        chunks.append(audio[:, start : start + chunk_size_in_samples])
+    return chunks, sr
```
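A quick usage example for the relocated chunking helper (the file path is a placeholder). Note that, despite its docstring and type hint, `create_sequential_chunks` returns the chunk tensors themselves along with the file's original sample rate:

```python
from remfx.utils import create_sequential_chunks

chunks, orig_sr = create_sequential_chunks("some_file.wav", chunk_size=6)
print(len(chunks), orig_sr)
if chunks:
    print(chunks[0].shape)  # [channels, chunk_size * orig_sr]
```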
shell_vars.sh CHANGED

```diff
@@ -1,4 +1,3 @@
 export DATASET_ROOT="./data/VocalSet"
-export OUTPUT_ROOT="/scratch/VocalSet/processed"
 export WANDB_PROJECT="RemFX"
 export WANDB_ENTITY="mattricesound"
```