mattricesound committed • Commit b175ee9 • 1 Parent: dfbeb31

Add new GuitarSet dataset. Add pedal board effects. Fix sample_rate mismatch bug
Files changed:
- config.yaml +1 -3
- config_guitfx.yaml +52 -0
- exp/umx.yaml +6 -1
- remfx/datasets.py +80 -21
- remfx/effects.py +698 -0
- remfx/models.py +5 -1
- setup.py +2 -0
- shell_vars.sh +1 -1
config.yaml CHANGED
@@ -3,7 +3,6 @@ defaults:
   - exp: null
 seed: 12345
 train: True
-length: 262144
 sample_rate: 48000
 logs_dir: "./logs"
 log_every_n_steps: 1000
@@ -22,10 +21,9 @@ callbacks:
 datamodule:
   _target_: remfx.datasets.Datamodule
   dataset:
-    _target_: remfx.datasets.
+    _target_: remfx.datasets.GuitarSet
     sample_rate: ${sample_rate}
     root: ${oc.env:DATASET_ROOT}
-    length: ${length}
     chunk_size_in_sec: 6
   val_split: 0.2
   batch_size: 16
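As a quick aside on how `${oc.env:DATASET_ROOT}` is filled in: OmegaConf's built-in `oc.env` resolver reads the environment variable at access time. A minimal sketch, assuming omegaconf >= 2.1 (the variable itself is normally exported by shell_vars.sh, updated at the bottom of this commit):

# Minimal sketch of OmegaConf env-variable interpolation; not part of the commit.
import os
from omegaconf import OmegaConf

os.environ["DATASET_ROOT"] = "./data/GuitarSet"  # normally set via shell_vars.sh
cfg = OmegaConf.create({"root": "${oc.env:DATASET_ROOT}"})
print(cfg.root)  # -> ./data/GuitarSet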
config_guitfx.yaml ADDED
@@ -0,0 +1,52 @@
+defaults:
+  - _self_
+  - exp: null
+seed: 12345
+train: True
+sample_rate: 48000
+logs_dir: "./logs"
+log_every_n_steps: 1000
+
+callbacks:
+  model_checkpoint:
+    _target_: pytorch_lightning.callbacks.ModelCheckpoint
+    monitor: "valid_loss" # name of the logged metric which determines when model is improving
+    save_top_k: 1 # save k best models (determined by above metric)
+    save_last: True # additionally always save model from last epoch
+    mode: "min" # can be "max" or "min"
+    verbose: False
+    dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}
+    filename: '{epoch:02d}-{valid_loss:.3f}'
+
+datamodule:
+  _target_: remfx.datasets.Datamodule
+  dataset:
+    _target_: remfx.datasets.GuitarFXDataset
+    sample_rate: ${sample_rate}
+    root: ${oc.env:DATASET_ROOT}
+    chunk_size_in_sec: 6
+  val_split: 0.2
+  batch_size: 16
+  num_workers: 8
+  pin_memory: True
+  persistent_workers: True
+
+logger:
+  _target_: pytorch_lightning.loggers.WandbLogger
+  project: ${oc.env:WANDB_PROJECT}
+  entity: ${oc.env:WANDB_ENTITY}
+  # offline: False # set True to store all logs only locally
+  job_type: "train"
+  group: ""
+  save_dir: "."
+
+trainer:
+  _target_: pytorch_lightning.Trainer
+  precision: 32 # Precision used for tensors, default `32`
+  min_epochs: 0
+  max_epochs: -1
+  enable_model_summary: False
+  log_every_n_steps: 1 # Logs metrics every N batches
+  accumulate_grad_batches: 1
+  accelerator: null
+  devices: 1
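For context, a sketch of how these `_target_` entries become live objects. This assumes a standard Hydra entry point using `hydra.utils.instantiate` (the project's actual training script is not part of this commit, and the function below is hypothetical):

# Hypothetical entry point; only the config keys are taken from this commit.
# Assumes Hydra >= 1.2 for the version_base argument.
import hydra
from omegaconf import DictConfig


@hydra.main(version_base=None, config_path=".", config_name="config_guitfx")
def main(cfg: DictConfig) -> None:
    # Recursively builds remfx.datasets.Datamodule, which receives an
    # instantiated remfx.datasets.GuitarFXDataset as its `dataset` argument.
    datamodule = hydra.utils.instantiate(cfg.datamodule)
    datamodule.setup()


if __name__ == "__main__":
    main()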
exp/umx.yaml CHANGED
@@ -16,4 +16,9 @@ model:
   sample_rate: ${sample_rate}
 datamodule:
   dataset:
-    effect_types:
+    effect_types:
+      Distortion:
+        _target_: remfx.effects.RandomPedalboardDistortion
+        sample_rate: ${sample_rate}
+        min_drive_db: -10
+        max_drive_db: 30
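After Hydra instantiation, `effect_types` becomes a plain dict mapping effect names to `torch.nn.Module` instances, which is what the new `GuitarSet.__getitem__` (see remfx/datasets.py below) indexes into. A standalone sketch of that selection step, using the real `RandomPedalboardDistortion` added in remfx/effects.py; building the dict by hand here is an illustration, not the project's code path:

import torch
from remfx.effects import RandomPedalboardDistortion

effect_types = {
    "Distortion": RandomPedalboardDistortion(
        sample_rate=48000, min_drive_db=-10, max_drive_db=30
    )
}
# torch.rand(1) is uniform on [0, 1), so int() always yields a valid index
idx = int(torch.rand(1).item() * len(effect_types))
effect_name = list(effect_types.keys())[idx]
dry = torch.randn(1, 48000)           # one second of mono audio at 48 kHz
wet = effect_types[effect_name](dry)  # effected chunk, same shape
print(effect_name, wet.shape)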
remfx/datasets.py CHANGED
@@ -7,10 +7,8 @@ from pathlib import Path
 import pytorch_lightning as pl
 from typing import Any, List, Tuple
 
-# https://zenodo.org/record/7044411/
-
-LENGTH = 2**18  # 12 seconds
-ORIG_SR = 48000
+# https://zenodo.org/record/7044411/ -> GuitarFX
+# https://zenodo.org/record/3371780 -> GuitarSet
 
 
 class GuitarFXDataset(Dataset):
@@ -18,11 +16,10 @@ class GuitarFXDataset(Dataset):
         self,
         root: str,
         sample_rate: int,
-        length: int = LENGTH,
         chunk_size_in_sec: int = 3,
         effect_types: List[str] = None,
     ):
-
+        super().__init__()
         self.wet_files = []
         self.dry_files = []
         self.chunks = []
@@ -30,6 +27,7 @@ class GuitarFXDataset(Dataset):
         self.song_idx = []
         self.root = Path(root)
         self.chunk_size_in_sec = chunk_size_in_sec
+        self.sample_rate = sample_rate
 
         if effect_types is None:
             effect_types = [
@@ -46,7 +44,7 @@ class GuitarFXDataset(Dataset):
             self.dry_files += dry_files
             self.labels += [i] * len(wet_files)
             for audio_file in wet_files:
-                chunk_starts = create_sequential_chunks(
+                chunk_starts, orig_sr = create_sequential_chunks(
                     audio_file, self.chunk_size_in_sec
                 )
                 self.chunks += chunk_starts
@@ -56,7 +54,7 @@ class GuitarFXDataset(Dataset):
             f"Found {len(self.wet_files)} wet files and {len(self.dry_files)} dry files.\n"
             f"Total chunks: {len(self.chunks)}"
         )
-        self.resampler = T.Resample(
+        self.resampler = T.Resample(orig_sr, sample_rate)
 
     def __len__(self):
         return len(self.chunks)
@@ -75,20 +73,79 @@ class GuitarFXDataset(Dataset):
 
         resampled_x = self.resampler(x)
         resampled_y = self.resampler(y)
-        #
-
-
-        if
-
+        # Reset chunk size to be new sample rate
+        chunk_size_in_samples = self.chunk_size_in_sec * self.sample_rate
+        # Pad to chunk_size if needed
+        if resampled_x.shape[-1] < chunk_size_in_samples:
+            resampled_x = F.pad(
+                resampled_x, (0, chunk_size_in_samples - resampled_x.shape[1])
+            )
+        if resampled_y.shape[-1] < chunk_size_in_samples:
+            resampled_y = F.pad(
+                resampled_y, (0, chunk_size_in_samples - resampled_y.shape[1])
+            )
         return (resampled_x, resampled_y, effect_label)
 
 
+class GuitarSet(Dataset):
+    def __init__(
+        self,
+        root: str,
+        sample_rate: int,
+        chunk_size_in_sec: int = 3,
+        effect_types: List[torch.nn.Module] = None,
+    ):
+        super().__init__()
+        self.chunks = []
+        self.song_idx = []
+        self.root = Path(root)
+        self.chunk_size_in_sec = chunk_size_in_sec
+        self.files = sorted(list(self.root.glob("./**/*.wav")))
+        self.sample_rate = sample_rate
+        for i, audio_file in enumerate(self.files):
+            chunk_starts, orig_sr = create_sequential_chunks(
+                audio_file, self.chunk_size_in_sec
+            )
+            self.chunks += chunk_starts
+            self.song_idx += [i] * len(chunk_starts)
+        print(f"Found {len(self.files)} files .\n" f"Total chunks: {len(self.chunks)}")
+        self.resampler = T.Resample(orig_sr, sample_rate)
+        self.effect_types = effect_types
+
+    def __len__(self):
+        return len(self.chunks)
+
+    def __getitem__(self, idx):
+        # Load and effect audio
+        song_idx = self.song_idx[idx]
+        x, sr = torchaudio.load(self.files[song_idx])
+        chunk_start = self.chunks[idx]
+        chunk_size_in_samples = self.chunk_size_in_sec * sr
+        x = x[:, chunk_start : chunk_start + chunk_size_in_samples]
+        resampled_x = self.resampler(x)
+        # Reset chunk size to be new sample rate
+        chunk_size_in_samples = self.chunk_size_in_sec * self.sample_rate
+        # Pad to chunk_size if needed
+        if resampled_x.shape[-1] < chunk_size_in_samples:
+            resampled_x = F.pad(
+                resampled_x, (0, chunk_size_in_samples - resampled_x.shape[1])
+            )
+        target = resampled_x
+
+        # Add random effect
+        random_effect_idx = torch.rand(1).item() * len(self.effect_types.keys())
+        effect_name = list(self.effect_types.keys())[int(random_effect_idx)]
+        effect = self.effect_types[effect_name]
+        effected_input = effect(resampled_x)
+        return (effected_input, target, effect_name)
+
+
 def create_random_chunks(
     audio_file: str, chunk_size: int, num_chunks: int
-) -> List[Tuple[int, int]]:
+) -> Tuple[List[Tuple[int, int]], int]:
     """Create num_chunks random chunks of size chunk_size (seconds)
     from an audio file.
-    Return sample_index of start of each chunk
+    Return sample_index of start of each chunk and original sr
     """
     audio, sr = torchaudio.load(audio_file)
     chunk_size_in_samples = chunk_size * sr
@@ -98,17 +155,19 @@ def create_random_chunks(
     for i in range(num_chunks):
         start = torch.randint(0, audio.shape[-1] - chunk_size_in_samples, (1,)).item()
         chunks.append(start)
-    return chunks
+    return chunks, sr
 
 
-def create_sequential_chunks(
+def create_sequential_chunks(
+    audio_file: str, chunk_size: int
+) -> Tuple[List[Tuple[int, int]], int]:
     """Create sequential chunks of size chunk_size (seconds) from an audio file.
-    Return sample_index of start of each chunk
+    Return sample_index of start of each chunk and original sr
     """
     audio, sr = torchaudio.load(audio_file)
     chunk_size_in_samples = chunk_size * sr
     chunk_starts = torch.arange(0, audio.shape[-1], chunk_size_in_samples)
-    return chunk_starts
+    return chunk_starts, sr
 
 
 class Datamodule(pl.LightningDataModule):
@@ -133,8 +192,8 @@ class Datamodule(pl.LightningDataModule):
 
     def setup(self, stage: Any = None) -> None:
        split = [1.0 - self.val_split, self.val_split]
-        train_size =
-        val_size =
+        train_size = round(split[0] * len(self.dataset))
+        val_size = round(split[1] * len(self.dataset))
        self.data_train, self.data_val = random_split(
            self.dataset, [train_size, val_size]
        )
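The sample_rate mismatch fix in both datasets follows one pattern: slice a chunk at the file's native rate, resample to the configured rate, then right-pad so every chunk is exactly `chunk_size_in_sec * sample_rate` samples long and batches cleanly. A standalone sketch of that logic (the rates and chunk length here are illustrative):

import torch
import torch.nn.functional as F
import torchaudio.transforms as T

orig_sr, sample_rate, chunk_size_in_sec = 44100, 48000, 6
resampler = T.Resample(orig_sr, sample_rate)

x = torch.randn(1, 4 * orig_sr)  # e.g. a trailing chunk shorter than 6 s
resampled_x = resampler(x)
chunk_size_in_samples = chunk_size_in_sec * sample_rate
if resampled_x.shape[-1] < chunk_size_in_samples:
    # zero-pad on the right so all chunks share a fixed length
    resampled_x = F.pad(
        resampled_x, (0, chunk_size_in_samples - resampled_x.shape[-1])
    )
assert resampled_x.shape[-1] == chunk_size_in_samples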
remfx/effects.py ADDED
@@ -0,0 +1,698 @@
+import torch
+import torchaudio
+import numpy as np
+import scipy.signal
+import scipy.stats
+import pyloudnorm as pyln
+from torchvision.transforms import Compose, RandomApply
+
+
+from typing import List
+from pedalboard import (
+    Pedalboard,
+    Chorus,
+    Reverb,
+    Compressor,
+    Phaser,
+    Delay,
+    Distortion,
+    Limiter,
+)
+
+__all__ = []
+
+
+def loguniform(low=0, high=1):
+    return scipy.stats.loguniform.rvs(low, high)
+
+
+def rand(low=0, high=1):
+    return (torch.rand(1).numpy()[0] * (high - low)) + low
+
+
+def randint(low=0, high=1):
+    return torch.randint(low, high + 1, (1,)).numpy()[0]
+
+
+def biqaud(
+    gain_db: float,
+    cutoff_freq: float,
+    q_factor: float,
+    sample_rate: float,
+    filter_type: str,
+):
+    """Use design parameters to generate coefficients for a specific filter type.
+    Args:
+        gain_db (float): Shelving filter gain in dB.
+        cutoff_freq (float): Cutoff frequency in Hz.
+        q_factor (float): Q factor.
+        sample_rate (float): Sample rate in Hz.
+        filter_type (str): Filter type.
+            One of "low_shelf", "high_shelf", or "peaking"
+    Returns:
+        b (np.ndarray): Numerator filter coefficients stored as [b0, b1, b2]
+        a (np.ndarray): Denominator filter coefficients stored as [a0, a1, a2]
+    """
+
+    A = 10 ** (gain_db / 40.0)
+    w0 = 2.0 * np.pi * (cutoff_freq / sample_rate)
+    alpha = np.sin(w0) / (2.0 * q_factor)
+
+    cos_w0 = np.cos(w0)
+    sqrt_A = np.sqrt(A)
+
+    if filter_type == "high_shelf":
+        b0 = A * ((A + 1) + (A - 1) * cos_w0 + 2 * sqrt_A * alpha)
+        b1 = -2 * A * ((A - 1) + (A + 1) * cos_w0)
+        b2 = A * ((A + 1) + (A - 1) * cos_w0 - 2 * sqrt_A * alpha)
+        a0 = (A + 1) - (A - 1) * cos_w0 + 2 * sqrt_A * alpha
+        a1 = 2 * ((A - 1) - (A + 1) * cos_w0)
+        a2 = (A + 1) - (A - 1) * cos_w0 - 2 * sqrt_A * alpha
+    elif filter_type == "low_shelf":
+        b0 = A * ((A + 1) - (A - 1) * cos_w0 + 2 * sqrt_A * alpha)
+        b1 = 2 * A * ((A - 1) - (A + 1) * cos_w0)
+        b2 = A * ((A + 1) - (A - 1) * cos_w0 - 2 * sqrt_A * alpha)
+        a0 = (A + 1) + (A - 1) * cos_w0 + 2 * sqrt_A * alpha
+        a1 = -2 * ((A - 1) + (A + 1) * cos_w0)
+        a2 = (A + 1) + (A - 1) * cos_w0 - 2 * sqrt_A * alpha
+    elif filter_type == "peaking":
+        b0 = 1 + alpha * A
+        b1 = -2 * cos_w0
+        b2 = 1 - alpha * A
+        a0 = 1 + alpha / A
+        a1 = -2 * cos_w0
+        a2 = 1 - alpha / A
+    else:
+        pass
+        # raise ValueError(f"Invalid filter_type: {filter_type}.")
+
+    b = np.array([b0, b1, b2]) / a0
+    a = np.array([a0, a1, a2]) / a0
+
+    return b, a
+
+
+def parametric_eq(
+    x: np.ndarray,
+    sample_rate: float,
+    low_shelf_gain_db: float = 0.0,
+    low_shelf_cutoff_freq: float = 80.0,
+    low_shelf_q_factor: float = 0.707,
+    band_gains_db: List[float] = [0.0],
+    band_cutoff_freqs: List[float] = [300.0],
+    band_q_factors: List[float] = [0.707],
+    high_shelf_gain_db: float = 0.0,
+    high_shelf_cutoff_freq: float = 1000.0,
+    high_shelf_q_factor: float = 0.707,
+    dtype=np.float32,
+):
+    """Multiband parametric EQ.
+    Low-shelf -> Band 1 -> ... -> Band N -> High-shelf
+    Args:
+    """
+    assert (
+        len(band_gains_db) == len(band_cutoff_freqs) == len(band_q_factors)
+    )  # must define for all bands
+
+    # -------- apply low-shelf filter --------
+    b, a = biqaud(
+        low_shelf_gain_db,
+        low_shelf_cutoff_freq,
+        low_shelf_q_factor,
+        sample_rate,
+        "low_shelf",
+    )
+    x = scipy.signal.lfilter(b, a, x)
+
+    # -------- apply peaking filters --------
+    for gain_db, cutoff_freq, q_factor in zip(
+        band_gains_db, band_cutoff_freqs, band_q_factors
+    ):
+        b, a = biqaud(
+            gain_db,
+            cutoff_freq,
+            q_factor,
+            sample_rate,
+            "peaking",
+        )
+        x = scipy.signal.lfilter(b, a, x)
+
+    # -------- apply high-shelf filter --------
+    b, a = biqaud(
+        high_shelf_gain_db,
+        high_shelf_cutoff_freq,
+        high_shelf_q_factor,
+        sample_rate,
+        "high_shelf",
+    )
+    sos5 = np.concatenate((b, a))
+    x = scipy.signal.lfilter(b, a, x)
+
+    return x.astype(dtype)
+
+
+class RandomParametricEQ(torch.nn.Module):
+    def __init__(
+        self,
+        sample_rate: float,
+        num_bands: int = 3,
+        min_gain_db: float = -6.0,
+        max_gain_db: float = +6.0,
+        min_cutoff_freq: float = 1000.0,
+        max_cutoff_freq: float = 10000.0,
+        min_q_factor: float = 0.1,
+        max_q_factor: float = 4.0,
+    ):
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.num_bands = num_bands
+        self.min_gain_db = min_gain_db
+        self.max_gain_db = max_gain_db
+        self.min_cutoff_freq = min_cutoff_freq
+        self.max_cutoff_freq = max_cutoff_freq
+        self.min_q_factor = min_q_factor
+        self.max_q_factor = max_q_factor
+
+    def forward(self, x: torch.Tensor):
+        """
+        Args:
+            x: (torch.Tensor): Array of audio samples with shape (chs, seq_len).
+                The filter will be applied to the final dimension, and by default the
+                same filter will be applied to all channels.
+        """
+        low_shelf_gain_db = rand(self.min_gain_db, self.max_gain_db)
+        low_shelf_cutoff_freq = loguniform(20.0, 200.0)
+        low_shelf_q_factor = rand(self.min_q_factor, self.max_q_factor)
+
+        high_shelf_gain_db = rand(self.min_gain_db, self.max_gain_db)
+        high_shelf_cutoff_freq = loguniform(8000.0, 16000.0)
+        high_shelf_q_factor = rand(self.min_q_factor, self.max_q_factor)
+
+        band_gain_dbs = []
+        band_cutoff_freqs = []
+        band_q_factors = []
+        for _ in range(self.num_bands):
+            band_gain_dbs.append(rand(self.min_gain_db, self.max_gain_db))
+            band_cutoff_freqs.append(
+                loguniform(self.min_cutoff_freq, self.max_cutoff_freq)
+            )
+            band_q_factors.append(rand(self.min_q_factor, self.max_q_factor))
+
+        y = parametric_eq(
+            x.numpy(),
+            self.sample_rate,
+            low_shelf_gain_db=low_shelf_gain_db,
+            low_shelf_cutoff_freq=low_shelf_cutoff_freq,
+            low_shelf_q_factor=low_shelf_q_factor,
+            band_gains_db=band_gain_dbs,
+            band_cutoff_freqs=band_cutoff_freqs,
+            band_q_factors=band_q_factors,
+            high_shelf_gain_db=high_shelf_gain_db,
+            high_shelf_cutoff_freq=high_shelf_cutoff_freq,
+            high_shelf_q_factor=high_shelf_q_factor,
+        )
+
+        return torch.from_numpy(y)
+
+
+def stereo_widener(x: torch.Tensor, width: torch.Tensor):
+    sqrt2 = np.sqrt(2)
+
+    left = x[0, ...]
+    right = x[1, ...]
+
+    mid = (left + right) / sqrt2
+    side = (left - right) / sqrt2
+
+    # amplify mid and side signal separately:
+    mid *= 2 * (1 - width)
+    side *= 2 * width
+
+    left = (mid + side) / sqrt2
+    right = (mid - side) / sqrt2
+
+    x = torch.stack((left, right), dim=0)
+
+    return x
+
+
+class RandomStereoWidener(torch.nn.Module):
+    def __init__(
+        self,
+        sample_rate: float,
+        min_width: float = 0.0,
+        max_width: float = 1.0,
+    ) -> None:
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.min_width = min_width
+        self.max_width = max_width
+
+    def forward(self, x: torch.Tensor):
+        width = rand(self.min_width, self.max_width)
+        return stereo_widener(x, width)
+
+
+class RandomVolumeAutomation(torch.nn.Module):
+    def __init__(
+        self,
+        sample_rate: float,
+        min_segments: int = 1,
+        max_segments: int = 3,
+        min_gain_db: float = -6.0,
+        max_gain_db: float = 6.0,
+    ) -> None:
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.min_segments = min_segments
+        self.max_segments = max_segments
+        self.min_gain_db = min_gain_db
+        self.max_gain_db = max_gain_db
+
+    def forward(self, x: torch.Tensor):
+        gain_db = torch.zeros(x.shape[-1]).type_as(x)
+
+        num_segments = randint(self.min_segments, self.max_segments)
+        segment_lengths = (
+            x.shape[-1]
+            * np.random.dirichlet([rand(0, 10) for _ in range(num_segments)], 1)
+        ).astype("int")[0]
+
+        samples_filled = 0
+        start_gain_db = 0
+        for idx in range(num_segments):
+            segment_samples = segment_lengths[idx]
+            if idx != 0:
+                start_gain_db = end_gain_db
+
+            # sample random end gain
+            end_gain_db = rand(self.min_gain_db, self.max_gain_db)
+            fade = torch.linspace(start_gain_db, end_gain_db, steps=segment_samples)
+            gain_db[samples_filled : samples_filled + segment_samples] = fade
+            samples_filled = samples_filled + segment_samples
+
+        # print(gain_db)
+        x *= 10 ** (gain_db / 20.0)
+        return x
+
+
+class RandomPedalboardCompressor(torch.nn.Module):
+    def __init__(
+        self,
+        sample_rate: float,
+        min_threshold_db: float = -42.0,
+        max_threshold_db: float = -6.0,
+        min_ratio: float = 1.5,
+        max_ratio: float = 4.0,
+        min_attack_ms: float = 1.0,
+        max_attack_ms: float = 50.0,
+        min_release_ms: float = 10.0,
+        max_release_ms: float = 250.0,
+    ) -> None:
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.min_threshold_db = min_threshold_db
+        self.max_threshold_db = max_threshold_db
+        self.min_ratio = min_ratio
+        self.max_ratio = max_ratio
+        self.min_attack_ms = min_attack_ms
+        self.max_attack_ms = max_attack_ms
+        self.min_release_ms = min_release_ms
+        self.max_release_ms = max_release_ms
+
+    def forward(self, x: torch.Tensor):
+        board = Pedalboard()
+        threshold_db = rand(self.min_threshold_db, self.max_threshold_db)
+        ratio = rand(self.min_ratio, self.max_ratio)
+        attack_ms = rand(self.min_attack_ms, self.max_attack_ms)
+        release_ms = rand(self.min_release_ms, self.max_release_ms)
+
+        board.append(
+            Compressor(
+                threshold_db=threshold_db,
+                ratio=ratio,
+                attack_ms=attack_ms,
+                release_ms=release_ms,
+            )
+        )
+
+        # process audio using the pedalboard
+        return torch.from_numpy(board(x.numpy(), self.sample_rate))
+
+
+class RandomPedalboardDelay(torch.nn.Module):
+    def __init__(
+        self,
+        sample_rate: float,
+        min_delay_seconds: float = 0.1,
+        max_delay_sconds: float = 1.0,
+        min_feedback: float = 0.05,
+        max_feedback: float = 0.6,
+        min_mix: float = 0.0,
+        max_mix: float = 0.7,
+    ) -> None:
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.min_delay_seconds = min_delay_seconds
+        self.max_delay_seconds = max_delay_sconds
+        self.min_feedback = min_feedback
+        self.max_feedback = max_feedback
+        self.min_mix = min_mix
+        self.max_mix = max_mix
+
+    def forward(self, x: torch.Tensor):
+        board = Pedalboard()
+        delay_seconds = loguniform(self.min_delay_seconds, self.max_delay_seconds)
+        feedback = rand(self.min_feedback, self.max_feedback)
+        mix = rand(self.min_mix, self.max_mix)
+        board.append(Delay(delay_seconds=delay_seconds, feedback=feedback, mix=mix))
+        return torch.from_numpy(board(x.numpy(), self.sample_rate))
+
+
+class RandomPedalboardChorus(torch.nn.Module):
+    def __init__(
+        self,
+        sample_rate: float,
+        min_rate_hz: float = 0.25,
+        max_rate_hz: float = 4.0,
+        min_depth: float = 0.0,
+        max_depth: float = 0.6,
+        min_centre_delay_ms: float = 5.0,
+        max_centre_delay_ms: float = 10.0,
+        min_feedback: float = 0.1,
+        max_feedback: float = 0.6,
+        min_mix: float = 0.1,
+        max_mix: float = 0.7,
+    ) -> None:
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.min_rate_hz = min_rate_hz
+        self.max_rate_hz = max_rate_hz
+        self.min_depth = min_depth
+        self.max_depth = max_depth
+        self.min_centre_delay_ms = min_centre_delay_ms
+        self.max_centre_delay_ms = max_centre_delay_ms
+        self.min_feedback = min_feedback
+        self.max_feedback = max_feedback
+        self.min_mix = min_mix
+        self.max_mix = max_mix
+
+    def forward(self, x: torch.Tensor):
+        board = Pedalboard()
+        rate_hz = rand(self.min_rate_hz, self.max_rate_hz)
+        depth = rand(self.min_depth, self.max_depth)
+        centre_delay_ms = rand(self.min_centre_delay_ms, self.max_centre_delay_ms)
+        feedback = rand(self.min_feedback, self.max_feedback)
+        mix = rand(self.min_mix, self.max_mix)
+        board.append(
+            Chorus(
+                rate_hz=rate_hz,
+                depth=depth,
+                centre_delay_ms=centre_delay_ms,
+                feedback=feedback,
+                mix=mix,
+            )
+        )
+        # process audio using the pedalboard
+        return torch.from_numpy(board(x.numpy(), self.sample_rate))
+
+
+class RandomPedalboardPhaser(torch.nn.Module):
+    def __init__(
+        self,
+        sample_rate: float,
+        min_rate_hz: float = 0.25,
+        max_rate_hz: float = 5.0,
+        min_depth: float = 0.1,
+        max_depth: float = 0.6,
+        min_centre_frequency_hz: float = 200.0,
+        max_centre_frequency_hz: float = 600.0,
+        min_feedback: float = 0.1,
+        max_feedback: float = 0.6,
+        min_mix: float = 0.1,
+        max_mix: float = 0.7,
+    ) -> None:
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.min_rate_hz = min_rate_hz
+        self.max_rate_hz = max_rate_hz
+        self.min_depth = min_depth
+        self.max_depth = max_depth
+        self.min_centre_frequency_hz = min_centre_frequency_hz
+        self.max_centre_frequency_hz = max_centre_frequency_hz
+        self.min_feedback = min_feedback
+        self.max_feedback = max_feedback
+        self.min_mix = min_mix
+        self.max_mix = max_mix
+
+    def forward(self, x: torch.Tensor):
+        board = Pedalboard()
+        rate_hz = rand(self.min_rate_hz, self.max_rate_hz)
+        depth = rand(self.min_depth, self.max_depth)
+        centre_frequency_hz = rand(
+            self.min_centre_frequency_hz, self.min_centre_frequency_hz
+        )
+        feedback = rand(self.min_feedback, self.max_feedback)
+        mix = rand(self.min_mix, self.max_mix)
+        board.append(
+            Phaser(
+                rate_hz=rate_hz,
+                depth=depth,
+                centre_frequency_hz=centre_frequency_hz,
+                feedback=feedback,
+                mix=mix,
+            )
+        )
+        # process audio using the pedalboard
+        return torch.from_numpy(board(x.numpy(), self.sample_rate))
+
+
+class RandomPedalboardLimiter(torch.nn.Module):
+    def __init__(
+        self,
+        sample_rate: float,
+        min_threshold_db: float = -32.0,
+        max_threshold_db: float = -6.0,
+        min_release_ms: float = 10.0,
+        max_release_ms: float = 300.0,
+    ) -> None:
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.min_threshold_db = min_threshold_db
+        self.max_threshold_db = max_threshold_db
+        self.min_release_ms = min_release_ms
+        self.max_release_ms = max_release_ms
+
+    def forward(self, x: torch.Tensor):
+        board = Pedalboard()
+        threshold_db = rand(self.min_threshold_db, self.max_threshold_db)
+        release_ms = rand(self.min_release_ms, self.max_release_ms)
+        board.append(
+            Limiter(
+                threshold_db=threshold_db,
+                release_ms=release_ms,
+            )
+        )
+        return torch.from_numpy(board(x.numpy(), self.sample_rate))
+
+
+class RandomPedalboardDistortion(torch.nn.Module):
+    def __init__(
+        self,
+        sample_rate: float,
+        min_drive_db: float = -20.0,
+        max_drive_db: float = 12.0,
+    ):
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.min_drive_db = min_drive_db
+        self.max_drive_db = max_drive_db
+
+    def forward(self, x: torch.Tensor):
+        board = Pedalboard()
+        drive_db = rand(self.min_drive_db, self.max_drive_db)
+        board.append(Distortion(drive_db=drive_db))
+        return torch.from_numpy(board(x.numpy(), self.sample_rate))
+
+
+class RandomSoxReverb(torch.nn.Module):
+    def __init__(
+        self,
+        sample_rate: float,
+        min_reverberance: float = 10.0,
+        max_reverberance: float = 100.0,
+        min_high_freq_damping: float = 0.0,
+        max_high_freq_damping: float = 100.0,
+        min_wet_dry: float = 0.0,
+        max_wet_dry: float = 1.0,
+        min_room_scale: float = 5.0,
+        max_room_scale: float = 100.0,
+        min_stereo_depth: float = 20.0,
+        max_stereo_depth: float = 100.0,
+        min_pre_delay: float = 0.0,
+        max_pre_delay: float = 100.0,
+    ) -> None:
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.min_reverberance = min_reverberance
+        self.max_reverberance = max_reverberance
+        self.min_high_freq_damping = min_high_freq_damping
+        self.max_high_freq_damping = max_high_freq_damping
+        self.min_wet_dry = min_wet_dry
+        self.max_wet_dry = max_wet_dry
+        self.min_room_scale = min_room_scale
+        self.max_room_scale = max_room_scale
+        self.min_stereo_depth = min_stereo_depth
+        self.max_stereo_depth = max_stereo_depth
+        self.min_pre_delay = min_pre_delay
+        self.max_pre_delay = max_pre_delay
+
+    def forward(self, x: torch.Tensor):
+        reverberance = rand(self.min_reverberance, self.max_reverberance)
+        high_freq_damping = rand(self.min_high_freq_damping, self.max_high_freq_damping)
+        room_scale = rand(self.min_room_scale, self.max_room_scale)
+        stereo_depth = rand(self.min_stereo_depth, self.max_stereo_depth)
+        wet_dry = rand(self.min_wet_dry, self.max_wet_dry)
+        pre_delay = rand(self.min_pre_delay, self.max_pre_delay)
+
+        effects = [
+            [
+                "reverb",
+                f"{reverberance}",
+                f"{high_freq_damping}",
+                f"{room_scale}",
+                f"{stereo_depth}",
+                f"{pre_delay}",
+                "--wet-only",
+            ]
+        ]
+        y, _ = torchaudio.sox_effects.apply_effects_tensor(
+            x, self.sample_rate, effects, channels_first=True
+        )
+
+        # manual wet/dry mix
+        return (x * (1 - wet_dry)) + (y * wet_dry)
+
+
+class RandomPebalboardReverb(torch.nn.Module):
+    def __init__(
+        self,
+        sample_rate: float,
+        min_room_size: float = 0.0,
+        max_room_size: float = 1.0,
+        min_damping: float = 0.0,
+        max_damping: float = 1.0,
+        min_wet_dry: float = 0.0,
+        max_wet_dry: float = 0.7,
+        min_width: float = 0.0,
+        max_width: float = 1.0,
+    ) -> None:
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.min_room_size = min_room_size
+        self.max_room_size = max_room_size
+        self.min_damping = min_damping
+        self.max_damping = max_damping
+        self.min_wet_dry = min_wet_dry
+        self.max_wet_dry = max_wet_dry
+        self.min_width = min_width
+        self.max_width = max_width
+
+    def forward(self, x: torch.Tensor):
+        board = Pedalboard()
+        room_size = rand(self.min_room_size, self.max_room_size)
+        damping = rand(self.min_damping, self.max_damping)
+        wet_dry = rand(self.min_wet_dry, self.max_wet_dry)
+        width = rand(self.min_width, self.max_width)
+
+        board.append(
+            Reverb(
+                room_size=room_size,
+                damping=damping,
+                wet_level=wet_dry,
+                dry_level=(1 - wet_dry),
+                width=width,
+            )
+        )
+
+        return torch.from_numpy(board(x.numpy(), self.sample_rate))
+
+
+class LoudnessNormalize(torch.nn.Module):
+    def __init__(self, sample_rate: float, target_lufs_db: float = -32.0) -> None:
+        super().__init__()
+        self.meter = pyln.Meter(sample_rate)
+        self.target_lufs_db = target_lufs_db
+
+    def forward(self, x: torch.Tensor):
+        x_lufs_db = self.meter.integrated_loudness(x.permute(1, 0).numpy())
+        delta_lufs_db = torch.tensor([self.target_lufs_db - x_lufs_db]).float()
+        gain_lin = 10.0 ** (delta_lufs_db.clamp(-120, 40.0) / 20.0)
+        return gain_lin * x
+
+
+class RandomAudioEffectsChannel(torch.nn.Module):
+    def __init__(
+        self,
+        sample_rate: float,
+        parametric_eq_prob: float = 0.7,
+        distortion_prob: float = 0.01,
+        delay_prob: float = 0.1,
+        chorus_prob: float = 0.01,
+        phaser_prob: float = 0.01,
+        compressor_prob: float = 0.4,
+        reverb_prob: float = 0.2,
+        stereo_widener_prob: float = 0.3,
+        limiter_prob: float = 0.3,
+        vol_automation_prob: float = 0.7,
+        target_lufs_db: float = -32.0,
+    ) -> None:
+        super().__init__()
+        self.transforms = Compose(
+            [
+                RandomApply(
+                    [RandomParametricEQ(sample_rate)],
+                    p=parametric_eq_prob,
+                ),
+                RandomApply(
+                    [RandomPedalboardDistortion(sample_rate)],
+                    p=distortion_prob,
+                ),
+                RandomApply(
+                    [RandomPedalboardDelay(sample_rate)],
+                    p=delay_prob,
+                ),
+                RandomApply(
+                    [RandomPedalboardChorus(sample_rate)],
+                    p=chorus_prob,
+                ),
+                RandomApply(
+                    [RandomPedalboardPhaser(sample_rate)],
+                    p=phaser_prob,
+                ),
+                RandomApply(
+                    [RandomPedalboardCompressor(sample_rate)],
+                    p=compressor_prob,
+                ),
+                RandomApply(
+                    [RandomPebalboardReverb(sample_rate)],
+                    p=reverb_prob,
+                ),
+                RandomApply(
+                    [RandomStereoWidener(sample_rate)],
+                    p=stereo_widener_prob,
+                ),
+                RandomApply(
+                    [RandomPedalboardLimiter(sample_rate)],
+                    p=limiter_prob,
+                ),
+                RandomApply(
+                    [RandomVolumeAutomation(sample_rate)],
+                    p=vol_automation_prob,
+                ),
+                LoudnessNormalize(sample_rate, target_lufs_db=target_lufs_db),
+            ]
+        )
+
+    def forward(self, x: torch.Tensor):
+        return self.transforms(x)
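For reference, a minimal usage sketch of the new module: a single randomized effect, and the full probabilistic channel. Stereo input is assumed for the channel, since the stereo widener and the loudness meter expect two channels; the tensor shapes below are illustrative:

import torch
from remfx.effects import RandomPedalboardDistortion, RandomAudioEffectsChannel

x = torch.randn(2, 48000)  # one second of stereo audio at 48 kHz

# One effect with explicit parameter ranges (as configured in exp/umx.yaml above)
dist = RandomPedalboardDistortion(sample_rate=48000, min_drive_db=-10, max_drive_db=30)
y = dist(x)

# The full chain: each effect fires with its configured probability,
# then the result is loudness-normalized to the target LUFS
channel = RandomAudioEffectsChannel(sample_rate=48000)
y = channel(x)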
remfx/models.py CHANGED
@@ -117,7 +117,11 @@ class RemFXModel(pl.LightningModule):
         y = self.model.sample(x)
 
         # Concat samples together for easier viewing in dashboard
-
+        # 2 seconds of silence between each sample
+        silence = torch.zeros_like(x)
+        silence = silence[:, : self.sample_rate * 2]
+
+        concat_samples = torch.cat([y, silence, x, silence, target], dim=-1)
         log_wandb_audio_batch(
             logger=self.logger,
             id="prediction_input_target",
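A shape-level sketch of the layout this builds for the W&B dashboard: prediction, input, and target laid end to end along the time axis, separated by two seconds of silence (the six-second clip length here is illustrative):

import torch

sample_rate = 48000
y = torch.randn(1, 6 * sample_rate)       # model prediction
x = torch.randn(1, 6 * sample_rate)       # effected input
target = torch.randn(1, 6 * sample_rate)  # clean target
silence = torch.zeros_like(x)[:, : sample_rate * 2]
concat_samples = torch.cat([y, silence, x, silence, target], dim=-1)
print(concat_samples.shape)  # (1, 22 * sample_rate), i.e. 22 s of audio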
setup.py CHANGED
@@ -44,6 +44,8 @@ setup(
         "librosa",
         "hydra-core",
         "auraloss",
+        "pyloudnorm",
+        "pedalboard",
     ],
     include_package_data=True,
     license="Apache License 2.0",
shell_vars.sh CHANGED
@@ -1,3 +1,3 @@
-export DATASET_ROOT="./data/
+export DATASET_ROOT="./data/GuitarSet"
 export WANDB_PROJECT="RemFX"
 export WANDB_ENTITY="mattricesound"