mattricesound committed on
Commit b175ee9
1 Parent(s): dfbeb31

Add new GuitarSet dataset. Add Pedalboard effects. Fix sample_rate mismatch bug

Files changed (8)
  1. config.yaml +1 -3
  2. config_guitfx.yaml +52 -0
  3. exp/umx.yaml +6 -1
  4. remfx/datasets.py +80 -21
  5. remfx/effects.py +698 -0
  6. remfx/models.py +5 -1
  7. setup.py +2 -0
  8. shell_vars.sh +1 -1
config.yaml CHANGED
@@ -3,7 +3,6 @@ defaults:
   - exp: null
 seed: 12345
 train: True
-length: 262144
 sample_rate: 48000
 logs_dir: "./logs"
 log_every_n_steps: 1000
@@ -22,10 +21,9 @@ callbacks:
 datamodule:
   _target_: remfx.datasets.Datamodule
   dataset:
-    _target_: remfx.datasets.GuitarFXDataset
+    _target_: remfx.datasets.GuitarSet
     sample_rate: ${sample_rate}
     root: ${oc.env:DATASET_ROOT}
-    length: ${length}
     chunk_size_in_sec: 6
     val_split: 0.2
     batch_size: 16
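
Note: the removed fixed `length: 262144` is superseded by time-based chunking; each chunk's length in samples is now derived from the two remaining config values. A quick worked check (illustrative, using the values above):

# Worked check (illustrative): chunk length now derives from the config values,
# replacing the removed fixed `length`.
sample_rate = 48000
chunk_size_in_sec = 6
chunk_size_in_samples = chunk_size_in_sec * sample_rate  # 288000 samples per 6 s chunk
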
config_guitfx.yaml ADDED
@@ -0,0 +1,52 @@
+defaults:
+  - _self_
+  - exp: null
+seed: 12345
+train: True
+sample_rate: 48000
+logs_dir: "./logs"
+log_every_n_steps: 1000
+
+callbacks:
+  model_checkpoint:
+    _target_: pytorch_lightning.callbacks.ModelCheckpoint
+    monitor: "valid_loss" # name of the logged metric which determines when model is improving
+    save_top_k: 1 # save k best models (determined by above metric)
+    save_last: True # additionally always save model from last epoch
+    mode: "min" # can be "max" or "min"
+    verbose: False
+    dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}
+    filename: '{epoch:02d}-{valid_loss:.3f}'
+
+datamodule:
+  _target_: remfx.datasets.Datamodule
+  dataset:
+    _target_: remfx.datasets.GuitarFXDataset
+    sample_rate: ${sample_rate}
+    root: ${oc.env:DATASET_ROOT}
+    chunk_size_in_sec: 6
+    val_split: 0.2
+    batch_size: 16
+    num_workers: 8
+    pin_memory: True
+    persistent_workers: True
+
+logger:
+  _target_: pytorch_lightning.loggers.WandbLogger
+  project: ${oc.env:WANDB_PROJECT}
+  entity: ${oc.env:WANDB_ENTITY}
+  # offline: False # set True to store all logs only locally
+  job_type: "train"
+  group: ""
+  save_dir: "."
+
+trainer:
+  _target_: pytorch_lightning.Trainer
+  precision: 32 # Precision used for tensors, default `32`
+  min_epochs: 0
+  max_epochs: -1
+  enable_model_summary: False
+  log_every_n_steps: 1 # Logs metrics every N batches
+  accumulate_grad_batches: 1
+  accelerator: null
+  devices: 1
exp/umx.yaml CHANGED
@@ -16,4 +16,9 @@ model:
   sample_rate: ${sample_rate}
 datamodule:
   dataset:
-    effect_types: ["RAT"]
+    effect_types:
+      Distortion:
+        _target_: remfx.effects.RandomPedalboardDistortion
+        sample_rate: ${sample_rate}
+        min_drive_db: -10
+        max_drive_db: 30
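
With this change, `effect_types` is no longer a list of effect-name strings but a mapping from effect names to Hydra `_target_` specs. A minimal sketch of how such a mapping resolves into live modules, assuming standard `hydra.utils.instantiate` semantics (the `build_effects` helper name is hypothetical):

# Minimal sketch (hypothetical helper): resolve the effect_types mapping above
# into instantiated torch modules via Hydra.
from hydra.utils import instantiate

def build_effects(dataset_cfg):
    # Each value carries a _target_ such as remfx.effects.RandomPedalboardDistortion
    return {name: instantiate(spec) for name, spec in dataset_cfg.effect_types.items()}
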
remfx/datasets.py CHANGED
@@ -7,10 +7,8 @@ from pathlib import Path
 import pytorch_lightning as pl
 from typing import Any, List, Tuple
 
-# https://zenodo.org/record/7044411/
-
-LENGTH = 2**18  # 12 seconds
-ORIG_SR = 48000
+# https://zenodo.org/record/7044411/ -> GuitarFX
+# https://zenodo.org/record/3371780 -> GuitarSet
 
 
 class GuitarFXDataset(Dataset):
@@ -18,11 +16,10 @@ class GuitarFXDataset(Dataset):
         self,
         root: str,
         sample_rate: int,
-        length: int = LENGTH,
         chunk_size_in_sec: int = 3,
         effect_types: List[str] = None,
     ):
-        self.length = length
+        super().__init__()
         self.wet_files = []
         self.dry_files = []
         self.chunks = []
@@ -30,6 +27,7 @@ class GuitarFXDataset(Dataset):
         self.song_idx = []
         self.root = Path(root)
         self.chunk_size_in_sec = chunk_size_in_sec
+        self.sample_rate = sample_rate
 
         if effect_types is None:
             effect_types = [
@@ -46,7 +44,7 @@ class GuitarFXDataset(Dataset):
             self.dry_files += dry_files
             self.labels += [i] * len(wet_files)
             for audio_file in wet_files:
-                chunk_starts = create_sequential_chunks(
+                chunk_starts, orig_sr = create_sequential_chunks(
                     audio_file, self.chunk_size_in_sec
                 )
                 self.chunks += chunk_starts
@@ -56,7 +54,7 @@ class GuitarFXDataset(Dataset):
             f"Found {len(self.wet_files)} wet files and {len(self.dry_files)} dry files.\n"
             f"Total chunks: {len(self.chunks)}"
         )
-        self.resampler = T.Resample(ORIG_SR, sample_rate)
+        self.resampler = T.Resample(orig_sr, sample_rate)
 
     def __len__(self):
         return len(self.chunks)
@@ -75,20 +73,79 @@ class GuitarFXDataset(Dataset):
 
         resampled_x = self.resampler(x)
         resampled_y = self.resampler(y)
-        # Pad to length if needed
-        if resampled_x.shape[-1] < self.length:
-            resampled_x = F.pad(resampled_x, (0, self.length - resampled_x.shape[1]))
-        if resampled_y.shape[-1] < self.length:
-            resampled_y = F.pad(resampled_y, (0, self.length - resampled_y.shape[1]))
+        # Reset chunk size to be new sample rate
+        chunk_size_in_samples = self.chunk_size_in_sec * self.sample_rate
+        # Pad to chunk_size if needed
+        if resampled_x.shape[-1] < chunk_size_in_samples:
+            resampled_x = F.pad(
+                resampled_x, (0, chunk_size_in_samples - resampled_x.shape[1])
+            )
+        if resampled_y.shape[-1] < chunk_size_in_samples:
+            resampled_y = F.pad(
+                resampled_y, (0, chunk_size_in_samples - resampled_y.shape[1])
+            )
         return (resampled_x, resampled_y, effect_label)
 
 
+class GuitarSet(Dataset):
+    def __init__(
+        self,
+        root: str,
+        sample_rate: int,
+        chunk_size_in_sec: int = 3,
+        effect_types: List[torch.nn.Module] = None,
+    ):
+        super().__init__()
+        self.chunks = []
+        self.song_idx = []
+        self.root = Path(root)
+        self.chunk_size_in_sec = chunk_size_in_sec
+        self.files = sorted(list(self.root.glob("./**/*.wav")))
+        self.sample_rate = sample_rate
+        for i, audio_file in enumerate(self.files):
+            chunk_starts, orig_sr = create_sequential_chunks(
+                audio_file, self.chunk_size_in_sec
+            )
+            self.chunks += chunk_starts
+            self.song_idx += [i] * len(chunk_starts)
+        print(f"Found {len(self.files)} files.\n" f"Total chunks: {len(self.chunks)}")
+        self.resampler = T.Resample(orig_sr, sample_rate)
+        self.effect_types = effect_types
+
+    def __len__(self):
+        return len(self.chunks)
+
+    def __getitem__(self, idx):
+        # Load and effect audio
+        song_idx = self.song_idx[idx]
+        x, sr = torchaudio.load(self.files[song_idx])
+        chunk_start = self.chunks[idx]
+        chunk_size_in_samples = self.chunk_size_in_sec * sr
+        x = x[:, chunk_start : chunk_start + chunk_size_in_samples]
+        resampled_x = self.resampler(x)
+        # Reset chunk size to be new sample rate
+        chunk_size_in_samples = self.chunk_size_in_sec * self.sample_rate
+        # Pad to chunk_size if needed
+        if resampled_x.shape[-1] < chunk_size_in_samples:
+            resampled_x = F.pad(
+                resampled_x, (0, chunk_size_in_samples - resampled_x.shape[1])
+            )
+        target = resampled_x
+
+        # Add random effect
+        random_effect_idx = torch.rand(1).item() * len(self.effect_types.keys())
+        effect_name = list(self.effect_types.keys())[int(random_effect_idx)]
+        effect = self.effect_types[effect_name]
+        effected_input = effect(resampled_x)
+        return (effected_input, target, effect_name)
+
+
 def create_random_chunks(
     audio_file: str, chunk_size: int, num_chunks: int
-) -> List[Tuple[int, int]]:
+) -> Tuple[List[Tuple[int, int]], int]:
     """Create num_chunks random chunks of size chunk_size (seconds)
     from an audio file.
-    Return sample_index of start of each chunk
+    Return sample_index of start of each chunk and original sr
     """
     audio, sr = torchaudio.load(audio_file)
    chunk_size_in_samples = chunk_size * sr
@@ -98,17 +155,19 @@ def create_random_chunks(
     for i in range(num_chunks):
         start = torch.randint(0, audio.shape[-1] - chunk_size_in_samples, (1,)).item()
         chunks.append(start)
-    return chunks
+    return chunks, sr
 
 
-def create_sequential_chunks(audio_file: str, chunk_size: int) -> List[Tuple[int, int]]:
+def create_sequential_chunks(
+    audio_file: str, chunk_size: int
+) -> Tuple[List[Tuple[int, int]], int]:
     """Create sequential chunks of size chunk_size (seconds) from an audio file.
-    Return sample_index of start of each chunk
+    Return sample_index of start of each chunk and original sr
     """
     audio, sr = torchaudio.load(audio_file)
     chunk_size_in_samples = chunk_size * sr
     chunk_starts = torch.arange(0, audio.shape[-1], chunk_size_in_samples)
-    return chunk_starts
+    return chunk_starts, sr
 
 
 class Datamodule(pl.LightningDataModule):
@@ -133,8 +192,8 @@ class Datamodule(pl.LightningDataModule):
 
     def setup(self, stage: Any = None) -> None:
         split = [1.0 - self.val_split, self.val_split]
-        train_size = int(split[0] * len(self.dataset))
-        val_size = int(split[1] * len(self.dataset))
+        train_size = round(split[0] * len(self.dataset))
+        val_size = round(split[1] * len(self.dataset))
         self.data_train, self.data_val = random_split(
             self.dataset, [train_size, val_size]
         )
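
A minimal usage sketch for the new `GuitarSet` dataset (illustrative, not part of the commit; the path and parameter values are assumptions taken from the configs in this commit):

from remfx.datasets import GuitarSet
from remfx.effects import RandomPedalboardDistortion

# Map effect names to effect modules, mirroring the exp/umx.yaml config above
effects = {
    "Distortion": RandomPedalboardDistortion(
        sample_rate=48000, min_drive_db=-10, max_drive_db=30
    )
}
dataset = GuitarSet(
    root="./data/GuitarSet",
    sample_rate=48000,
    chunk_size_in_sec=6,
    effect_types=effects,
)
# Each item is (wet input, dry target, name of the randomly applied effect)
effected_input, target, effect_name = dataset[0]
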
remfx/effects.py ADDED
@@ -0,0 +1,698 @@
+import torch
+import torchaudio
+import numpy as np
+import scipy.signal
+import scipy.stats
+import pyloudnorm as pyln
+from torchvision.transforms import Compose, RandomApply
+
+
+from typing import List
+from pedalboard import (
+    Pedalboard,
+    Chorus,
+    Reverb,
+    Compressor,
+    Phaser,
+    Delay,
+    Distortion,
+    Limiter,
+)
+
+__all__ = []
+
+
+def loguniform(low=0, high=1):
+    return scipy.stats.loguniform.rvs(low, high)
+
+
+def rand(low=0, high=1):
+    return (torch.rand(1).numpy()[0] * (high - low)) + low
+
+
+def randint(low=0, high=1):
+    return torch.randint(low, high + 1, (1,)).numpy()[0]
+
+
+def biqaud(
+    gain_db: float,
+    cutoff_freq: float,
+    q_factor: float,
+    sample_rate: float,
+    filter_type: str,
+):
+    """Use design parameters to generate coefficients for a specific filter type.
+    Args:
+        gain_db (float): Shelving filter gain in dB.
+        cutoff_freq (float): Cutoff frequency in Hz.
+        q_factor (float): Q factor.
+        sample_rate (float): Sample rate in Hz.
+        filter_type (str): Filter type.
+            One of "low_shelf", "high_shelf", or "peaking"
+    Returns:
+        b (np.ndarray): Numerator filter coefficients stored as [b0, b1, b2]
+        a (np.ndarray): Denominator filter coefficients stored as [a0, a1, a2]
+    """
+
+    A = 10 ** (gain_db / 40.0)
+    w0 = 2.0 * np.pi * (cutoff_freq / sample_rate)
+    alpha = np.sin(w0) / (2.0 * q_factor)
+
+    cos_w0 = np.cos(w0)
+    sqrt_A = np.sqrt(A)
+
+    if filter_type == "high_shelf":
+        b0 = A * ((A + 1) + (A - 1) * cos_w0 + 2 * sqrt_A * alpha)
+        b1 = -2 * A * ((A - 1) + (A + 1) * cos_w0)
+        b2 = A * ((A + 1) + (A - 1) * cos_w0 - 2 * sqrt_A * alpha)
+        a0 = (A + 1) - (A - 1) * cos_w0 + 2 * sqrt_A * alpha
+        a1 = 2 * ((A - 1) - (A + 1) * cos_w0)
+        a2 = (A + 1) - (A - 1) * cos_w0 - 2 * sqrt_A * alpha
+    elif filter_type == "low_shelf":
+        b0 = A * ((A + 1) - (A - 1) * cos_w0 + 2 * sqrt_A * alpha)
+        b1 = 2 * A * ((A - 1) - (A + 1) * cos_w0)
+        b2 = A * ((A + 1) - (A - 1) * cos_w0 - 2 * sqrt_A * alpha)
+        a0 = (A + 1) + (A - 1) * cos_w0 + 2 * sqrt_A * alpha
+        a1 = -2 * ((A - 1) + (A + 1) * cos_w0)
+        a2 = (A + 1) + (A - 1) * cos_w0 - 2 * sqrt_A * alpha
+    elif filter_type == "peaking":
+        b0 = 1 + alpha * A
+        b1 = -2 * cos_w0
+        b2 = 1 - alpha * A
+        a0 = 1 + alpha / A
+        a1 = -2 * cos_w0
+        a2 = 1 - alpha / A
+    else:
+        pass
+        # raise ValueError(f"Invalid filter_type: {filter_type}.")
+
+    b = np.array([b0, b1, b2]) / a0
+    a = np.array([a0, a1, a2]) / a0
+
+    return b, a
+
+
+def parametric_eq(
+    x: np.ndarray,
+    sample_rate: float,
+    low_shelf_gain_db: float = 0.0,
+    low_shelf_cutoff_freq: float = 80.0,
+    low_shelf_q_factor: float = 0.707,
+    band_gains_db: List[float] = [0.0],
+    band_cutoff_freqs: List[float] = [300.0],
+    band_q_factors: List[float] = [0.707],
+    high_shelf_gain_db: float = 0.0,
+    high_shelf_cutoff_freq: float = 1000.0,
+    high_shelf_q_factor: float = 0.707,
+    dtype=np.float32,
+):
+    """Multiband parametric EQ.
+    Low-shelf -> Band 1 -> ... -> Band N -> High-shelf
+    Args:
+    """
+    assert (
+        len(band_gains_db) == len(band_cutoff_freqs) == len(band_q_factors)
+    )  # must define for all bands
+
+    # -------- apply low-shelf filter --------
+    b, a = biqaud(
+        low_shelf_gain_db,
+        low_shelf_cutoff_freq,
+        low_shelf_q_factor,
+        sample_rate,
+        "low_shelf",
+    )
+    x = scipy.signal.lfilter(b, a, x)
+
+    # -------- apply peaking filters --------
+    for gain_db, cutoff_freq, q_factor in zip(
+        band_gains_db, band_cutoff_freqs, band_q_factors
+    ):
+        b, a = biqaud(
+            gain_db,
+            cutoff_freq,
+            q_factor,
+            sample_rate,
+            "peaking",
+        )
+        x = scipy.signal.lfilter(b, a, x)
+
+    # -------- apply high-shelf filter --------
+    b, a = biqaud(
+        high_shelf_gain_db,
+        high_shelf_cutoff_freq,
+        high_shelf_q_factor,
+        sample_rate,
+        "high_shelf",
+    )
+    sos5 = np.concatenate((b, a))
+    x = scipy.signal.lfilter(b, a, x)
+
+    return x.astype(dtype)
+
+
+class RandomParametricEQ(torch.nn.Module):
+    def __init__(
+        self,
+        sample_rate: float,
+        num_bands: int = 3,
+        min_gain_db: float = -6.0,
+        max_gain_db: float = +6.0,
+        min_cutoff_freq: float = 1000.0,
+        max_cutoff_freq: float = 10000.0,
+        min_q_factor: float = 0.1,
+        max_q_factor: float = 4.0,
+    ):
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.num_bands = num_bands
+        self.min_gain_db = min_gain_db
+        self.max_gain_db = max_gain_db
+        self.min_cutoff_freq = min_cutoff_freq
+        self.max_cutoff_freq = max_cutoff_freq
+        self.min_q_factor = min_q_factor
+        self.max_q_factor = max_q_factor
+
+    def forward(self, x: torch.Tensor):
+        """
+        Args:
+            x: (torch.Tensor): Array of audio samples with shape (chs, seq_len).
+                The filter will be applied to the final dimension, and by default
+                the same filter will be applied to all channels.
+        """
+        low_shelf_gain_db = rand(self.min_gain_db, self.max_gain_db)
+        low_shelf_cutoff_freq = loguniform(20.0, 200.0)
+        low_shelf_q_factor = rand(self.min_q_factor, self.max_q_factor)
+
+        high_shelf_gain_db = rand(self.min_gain_db, self.max_gain_db)
+        high_shelf_cutoff_freq = loguniform(8000.0, 16000.0)
+        high_shelf_q_factor = rand(self.min_q_factor, self.max_q_factor)
+
+        band_gain_dbs = []
+        band_cutoff_freqs = []
+        band_q_factors = []
+        for _ in range(self.num_bands):
+            band_gain_dbs.append(rand(self.min_gain_db, self.max_gain_db))
+            band_cutoff_freqs.append(
+                loguniform(self.min_cutoff_freq, self.max_cutoff_freq)
+            )
+            band_q_factors.append(rand(self.min_q_factor, self.max_q_factor))
+
+        y = parametric_eq(
+            x.numpy(),
+            self.sample_rate,
+            low_shelf_gain_db=low_shelf_gain_db,
+            low_shelf_cutoff_freq=low_shelf_cutoff_freq,
+            low_shelf_q_factor=low_shelf_q_factor,
+            band_gains_db=band_gain_dbs,
+            band_cutoff_freqs=band_cutoff_freqs,
+            band_q_factors=band_q_factors,
+            high_shelf_gain_db=high_shelf_gain_db,
+            high_shelf_cutoff_freq=high_shelf_cutoff_freq,
+            high_shelf_q_factor=high_shelf_q_factor,
+        )
+
+        return torch.from_numpy(y)
+
+
+def stereo_widener(x: torch.Tensor, width: torch.Tensor):
+    sqrt2 = np.sqrt(2)
+
+    left = x[0, ...]
+    right = x[1, ...]
+
+    mid = (left + right) / sqrt2
+    side = (left - right) / sqrt2
+
+    # amplify mid and side signal separately:
+    mid *= 2 * (1 - width)
+    side *= 2 * width
+
+    left = (mid + side) / sqrt2
+    right = (mid - side) / sqrt2
+
+    x = torch.stack((left, right), dim=0)
+
+    return x
+
+
+class RandomStereoWidener(torch.nn.Module):
+    def __init__(
+        self,
+        sample_rate: float,
+        min_width: float = 0.0,
+        max_width: float = 1.0,
+    ) -> None:
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.min_width = min_width
+        self.max_width = max_width
+
+    def forward(self, x: torch.Tensor):
+        width = rand(self.min_width, self.max_width)
+        return stereo_widener(x, width)
+
+
+class RandomVolumeAutomation(torch.nn.Module):
+    def __init__(
+        self,
+        sample_rate: float,
+        min_segments: int = 1,
+        max_segments: int = 3,
+        min_gain_db: float = -6.0,
+        max_gain_db: float = 6.0,
+    ) -> None:
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.min_segments = min_segments
+        self.max_segments = max_segments
+        self.min_gain_db = min_gain_db
+        self.max_gain_db = max_gain_db
+
+    def forward(self, x: torch.Tensor):
+        gain_db = torch.zeros(x.shape[-1]).type_as(x)
+
+        num_segments = randint(self.min_segments, self.max_segments)
+        segment_lengths = (
+            x.shape[-1]
+            * np.random.dirichlet([rand(0, 10) for _ in range(num_segments)], 1)
+        ).astype("int")[0]
+
+        samples_filled = 0
+        start_gain_db = 0
+        for idx in range(num_segments):
+            segment_samples = segment_lengths[idx]
+            if idx != 0:
+                start_gain_db = end_gain_db
+
+            # sample random end gain
+            end_gain_db = rand(self.min_gain_db, self.max_gain_db)
+            fade = torch.linspace(start_gain_db, end_gain_db, steps=segment_samples)
+            gain_db[samples_filled : samples_filled + segment_samples] = fade
+            samples_filled = samples_filled + segment_samples
+
+        # print(gain_db)
+        x *= 10 ** (gain_db / 20.0)
+        return x
+
+
+class RandomPedalboardCompressor(torch.nn.Module):
+    def __init__(
+        self,
+        sample_rate: float,
+        min_threshold_db: float = -42.0,
+        max_threshold_db: float = -6.0,
+        min_ratio: float = 1.5,
+        max_ratio: float = 4.0,
+        min_attack_ms: float = 1.0,
+        max_attack_ms: float = 50.0,
+        min_release_ms: float = 10.0,
+        max_release_ms: float = 250.0,
+    ) -> None:
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.min_threshold_db = min_threshold_db
+        self.max_threshold_db = max_threshold_db
+        self.min_ratio = min_ratio
+        self.max_ratio = max_ratio
+        self.min_attack_ms = min_attack_ms
+        self.max_attack_ms = max_attack_ms
+        self.min_release_ms = min_release_ms
+        self.max_release_ms = max_release_ms
+
+    def forward(self, x: torch.Tensor):
+        board = Pedalboard()
+        threshold_db = rand(self.min_threshold_db, self.max_threshold_db)
+        ratio = rand(self.min_ratio, self.max_ratio)
+        attack_ms = rand(self.min_attack_ms, self.max_attack_ms)
+        release_ms = rand(self.min_release_ms, self.max_release_ms)
+
+        board.append(
+            Compressor(
+                threshold_db=threshold_db,
+                ratio=ratio,
+                attack_ms=attack_ms,
+                release_ms=release_ms,
+            )
+        )
+
+        # process audio using the pedalboard
+        return torch.from_numpy(board(x.numpy(), self.sample_rate))
+
+
+class RandomPedalboardDelay(torch.nn.Module):
+    def __init__(
+        self,
+        sample_rate: float,
+        min_delay_seconds: float = 0.1,
+        max_delay_sconds: float = 1.0,
+        min_feedback: float = 0.05,
+        max_feedback: float = 0.6,
+        min_mix: float = 0.0,
+        max_mix: float = 0.7,
+    ) -> None:
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.min_delay_seconds = min_delay_seconds
+        self.max_delay_seconds = max_delay_sconds
+        self.min_feedback = min_feedback
+        self.max_feedback = max_feedback
+        self.min_mix = min_mix
+        self.max_mix = max_mix
+
+    def forward(self, x: torch.Tensor):
+        board = Pedalboard()
+        delay_seconds = loguniform(self.min_delay_seconds, self.max_delay_seconds)
+        feedback = rand(self.min_feedback, self.max_feedback)
+        mix = rand(self.min_mix, self.max_mix)
+        board.append(Delay(delay_seconds=delay_seconds, feedback=feedback, mix=mix))
+        return torch.from_numpy(board(x.numpy(), self.sample_rate))
+
+
+class RandomPedalboardChorus(torch.nn.Module):
+    def __init__(
+        self,
+        sample_rate: float,
+        min_rate_hz: float = 0.25,
+        max_rate_hz: float = 4.0,
+        min_depth: float = 0.0,
+        max_depth: float = 0.6,
+        min_centre_delay_ms: float = 5.0,
+        max_centre_delay_ms: float = 10.0,
+        min_feedback: float = 0.1,
+        max_feedback: float = 0.6,
+        min_mix: float = 0.1,
+        max_mix: float = 0.7,
+    ) -> None:
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.min_rate_hz = min_rate_hz
+        self.max_rate_hz = max_rate_hz
+        self.min_depth = min_depth
+        self.max_depth = max_depth
+        self.min_centre_delay_ms = min_centre_delay_ms
+        self.max_centre_delay_ms = max_centre_delay_ms
+        self.min_feedback = min_feedback
+        self.max_feedback = max_feedback
+        self.min_mix = min_mix
+        self.max_mix = max_mix
+
+    def forward(self, x: torch.Tensor):
+        board = Pedalboard()
+        rate_hz = rand(self.min_rate_hz, self.max_rate_hz)
+        depth = rand(self.min_depth, self.max_depth)
+        centre_delay_ms = rand(self.min_centre_delay_ms, self.max_centre_delay_ms)
+        feedback = rand(self.min_feedback, self.max_feedback)
+        mix = rand(self.min_mix, self.max_mix)
+        board.append(
+            Chorus(
+                rate_hz=rate_hz,
+                depth=depth,
+                centre_delay_ms=centre_delay_ms,
+                feedback=feedback,
+                mix=mix,
+            )
+        )
+        # process audio using the pedalboard
+        return torch.from_numpy(board(x.numpy(), self.sample_rate))
+
+
+class RandomPedalboardPhaser(torch.nn.Module):
+    def __init__(
+        self,
+        sample_rate: float,
+        min_rate_hz: float = 0.25,
+        max_rate_hz: float = 5.0,
+        min_depth: float = 0.1,
+        max_depth: float = 0.6,
+        min_centre_frequency_hz: float = 200.0,
+        max_centre_frequency_hz: float = 600.0,
+        min_feedback: float = 0.1,
+        max_feedback: float = 0.6,
+        min_mix: float = 0.1,
+        max_mix: float = 0.7,
+    ) -> None:
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.min_rate_hz = min_rate_hz
+        self.max_rate_hz = max_rate_hz
+        self.min_depth = min_depth
+        self.max_depth = max_depth
+        self.min_centre_frequency_hz = min_centre_frequency_hz
+        self.max_centre_frequency_hz = max_centre_frequency_hz
+        self.min_feedback = min_feedback
+        self.max_feedback = max_feedback
+        self.min_mix = min_mix
+        self.max_mix = max_mix
+
+    def forward(self, x: torch.Tensor):
+        board = Pedalboard()
+        rate_hz = rand(self.min_rate_hz, self.max_rate_hz)
+        depth = rand(self.min_depth, self.max_depth)
+        centre_frequency_hz = rand(
+            self.min_centre_frequency_hz, self.max_centre_frequency_hz
+        )
+        feedback = rand(self.min_feedback, self.max_feedback)
+        mix = rand(self.min_mix, self.max_mix)
+        board.append(
+            Phaser(
+                rate_hz=rate_hz,
+                depth=depth,
+                centre_frequency_hz=centre_frequency_hz,
+                feedback=feedback,
+                mix=mix,
+            )
+        )
+        # process audio using the pedalboard
+        return torch.from_numpy(board(x.numpy(), self.sample_rate))
+
+
+class RandomPedalboardLimiter(torch.nn.Module):
+    def __init__(
+        self,
+        sample_rate: float,
+        min_threshold_db: float = -32.0,
+        max_threshold_db: float = -6.0,
+        min_release_ms: float = 10.0,
+        max_release_ms: float = 300.0,
+    ) -> None:
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.min_threshold_db = min_threshold_db
+        self.max_threshold_db = max_threshold_db
+        self.min_release_ms = min_release_ms
+        self.max_release_ms = max_release_ms
+
+    def forward(self, x: torch.Tensor):
+        board = Pedalboard()
+        threshold_db = rand(self.min_threshold_db, self.max_threshold_db)
+        release_ms = rand(self.min_release_ms, self.max_release_ms)
+        board.append(
+            Limiter(
+                threshold_db=threshold_db,
+                release_ms=release_ms,
+            )
+        )
+        return torch.from_numpy(board(x.numpy(), self.sample_rate))
+
+
+class RandomPedalboardDistortion(torch.nn.Module):
+    def __init__(
+        self,
+        sample_rate: float,
+        min_drive_db: float = -20.0,
+        max_drive_db: float = 12.0,
+    ):
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.min_drive_db = min_drive_db
+        self.max_drive_db = max_drive_db
+
+    def forward(self, x: torch.Tensor):
+        board = Pedalboard()
+        drive_db = rand(self.min_drive_db, self.max_drive_db)
+        board.append(Distortion(drive_db=drive_db))
+        return torch.from_numpy(board(x.numpy(), self.sample_rate))
+
+
+class RandomSoxReverb(torch.nn.Module):
+    def __init__(
+        self,
+        sample_rate: float,
+        min_reverberance: float = 10.0,
+        max_reverberance: float = 100.0,
+        min_high_freq_damping: float = 0.0,
+        max_high_freq_damping: float = 100.0,
+        min_wet_dry: float = 0.0,
+        max_wet_dry: float = 1.0,
+        min_room_scale: float = 5.0,
+        max_room_scale: float = 100.0,
+        min_stereo_depth: float = 20.0,
+        max_stereo_depth: float = 100.0,
+        min_pre_delay: float = 0.0,
+        max_pre_delay: float = 100.0,
+    ) -> None:
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.min_reverberance = min_reverberance
+        self.max_reverberance = max_reverberance
+        self.min_high_freq_damping = min_high_freq_damping
+        self.max_high_freq_damping = max_high_freq_damping
+        self.min_wet_dry = min_wet_dry
+        self.max_wet_dry = max_wet_dry
+        self.min_room_scale = min_room_scale
+        self.max_room_scale = max_room_scale
+        self.min_stereo_depth = min_stereo_depth
+        self.max_stereo_depth = max_stereo_depth
+        self.min_pre_delay = min_pre_delay
+        self.max_pre_delay = max_pre_delay
+
+    def forward(self, x: torch.Tensor):
+        reverberance = rand(self.min_reverberance, self.max_reverberance)
+        high_freq_damping = rand(self.min_high_freq_damping, self.max_high_freq_damping)
+        room_scale = rand(self.min_room_scale, self.max_room_scale)
+        stereo_depth = rand(self.min_stereo_depth, self.max_stereo_depth)
+        wet_dry = rand(self.min_wet_dry, self.max_wet_dry)
+        pre_delay = rand(self.min_pre_delay, self.max_pre_delay)
+
+        effects = [
+            [
+                "reverb",
+                f"{reverberance}",
+                f"{high_freq_damping}",
+                f"{room_scale}",
+                f"{stereo_depth}",
+                f"{pre_delay}",
+                "--wet-only",
+            ]
+        ]
+        y, _ = torchaudio.sox_effects.apply_effects_tensor(
+            x, self.sample_rate, effects, channels_first=True
+        )
+
+        # manual wet/dry mix
+        return (x * (1 - wet_dry)) + (y * wet_dry)
+
+
+class RandomPebalboardReverb(torch.nn.Module):
+    def __init__(
+        self,
+        sample_rate: float,
+        min_room_size: float = 0.0,
+        max_room_size: float = 1.0,
+        min_damping: float = 0.0,
+        max_damping: float = 1.0,
+        min_wet_dry: float = 0.0,
+        max_wet_dry: float = 0.7,
+        min_width: float = 0.0,
+        max_width: float = 1.0,
+    ) -> None:
+        super().__init__()
+        self.sample_rate = sample_rate
+        self.min_room_size = min_room_size
+        self.max_room_size = max_room_size
+        self.min_damping = min_damping
+        self.max_damping = max_damping
+        self.min_wet_dry = min_wet_dry
+        self.max_wet_dry = max_wet_dry
+        self.min_width = min_width
+        self.max_width = max_width
+
+    def forward(self, x: torch.Tensor):
+        board = Pedalboard()
+        room_size = rand(self.min_room_size, self.max_room_size)
+        damping = rand(self.min_damping, self.max_damping)
+        wet_dry = rand(self.min_wet_dry, self.max_wet_dry)
+        width = rand(self.min_width, self.max_width)
+
+        board.append(
+            Reverb(
+                room_size=room_size,
+                damping=damping,
+                wet_level=wet_dry,
+                dry_level=(1 - wet_dry),
+                width=width,
+            )
+        )
+
+        return torch.from_numpy(board(x.numpy(), self.sample_rate))
+
+
+class LoudnessNormalize(torch.nn.Module):
+    def __init__(self, sample_rate: float, target_lufs_db: float = -32.0) -> None:
+        super().__init__()
+        self.meter = pyln.Meter(sample_rate)
+        self.target_lufs_db = target_lufs_db
+
+    def forward(self, x: torch.Tensor):
+        x_lufs_db = self.meter.integrated_loudness(x.permute(1, 0).numpy())
+        delta_lufs_db = torch.tensor([self.target_lufs_db - x_lufs_db]).float()
+        gain_lin = 10.0 ** (delta_lufs_db.clamp(-120, 40.0) / 20.0)
+        return gain_lin * x
+
+
+class RandomAudioEffectsChannel(torch.nn.Module):
+    def __init__(
+        self,
+        sample_rate: float,
+        parametric_eq_prob: float = 0.7,
+        distortion_prob: float = 0.01,
+        delay_prob: float = 0.1,
+        chorus_prob: float = 0.01,
+        phaser_prob: float = 0.01,
+        compressor_prob: float = 0.4,
+        reverb_prob: float = 0.2,
+        stereo_widener_prob: float = 0.3,
+        limiter_prob: float = 0.3,
+        vol_automation_prob: float = 0.7,
+        target_lufs_db: float = -32.0,
+    ) -> None:
+        super().__init__()
+        self.transforms = Compose(
+            [
+                RandomApply(
+                    [RandomParametricEQ(sample_rate)],
+                    p=parametric_eq_prob,
+                ),
+                RandomApply(
+                    [RandomPedalboardDistortion(sample_rate)],
+                    p=distortion_prob,
+                ),
+                RandomApply(
+                    [RandomPedalboardDelay(sample_rate)],
+                    p=delay_prob,
+                ),
+                RandomApply(
+                    [RandomPedalboardChorus(sample_rate)],
+                    p=chorus_prob,
+                ),
+                RandomApply(
+                    [RandomPedalboardPhaser(sample_rate)],
+                    p=phaser_prob,
+                ),
+                RandomApply(
+                    [RandomPedalboardCompressor(sample_rate)],
+                    p=compressor_prob,
+                ),
+                RandomApply(
+                    [RandomPebalboardReverb(sample_rate)],
+                    p=reverb_prob,
+                ),
+                RandomApply(
+                    [RandomStereoWidener(sample_rate)],
+                    p=stereo_widener_prob,
+                ),
+                RandomApply(
+                    [RandomPedalboardLimiter(sample_rate)],
+                    p=limiter_prob,
+                ),
+                RandomApply(
+                    [RandomVolumeAutomation(sample_rate)],
+                    p=vol_automation_prob,
+                ),
+                LoudnessNormalize(sample_rate, target_lufs_db=target_lufs_db),
+            ]
+        )
+
+    def forward(self, x: torch.Tensor):
+        return self.transforms(x)
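
A short sketch of running the full random effects chain on a clip (illustrative; the stereo shape matters because `stereo_widener` indexes channels 0 and 1):

import torch
from remfx.effects import RandomAudioEffectsChannel

augment = RandomAudioEffectsChannel(sample_rate=48000)
x = 0.1 * torch.randn(2, 48000 * 6)  # (channels, samples): 6 s of quiet stereo noise
y = augment(x)  # randomly applied EQ/dynamics/modulation, then normalized to -32 LUFS
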
remfx/models.py CHANGED
@@ -117,7 +117,11 @@ class RemFXModel(pl.LightningModule):
         y = self.model.sample(x)
 
         # Concat samples together for easier viewing in dashboard
-        concat_samples = torch.cat([y, x, target], dim=-1)
+        # 2 seconds of silence between each sample
+        silence = torch.zeros_like(x)
+        silence = silence[:, : self.sample_rate * 2]
+
+        concat_samples = torch.cat([y, silence, x, silence, target], dim=-1)
         log_wandb_audio_batch(
             logger=self.logger,
             id="prediction_input_target",
setup.py CHANGED
@@ -44,6 +44,8 @@ setup(
         "librosa",
         "hydra-core",
         "auraloss",
+        "pyloudnorm",
+        "pedalboard",
     ],
     include_package_data=True,
     license="Apache License 2.0",
shell_vars.sh CHANGED
@@ -1,3 +1,3 @@
-export DATASET_ROOT="./data/egfx"
+export DATASET_ROOT="./data/GuitarSet"
 export WANDB_PROJECT="RemFX"
 export WANDB_ENTITY="mattricesound"