mattricesound committed on
Commit 14ae0ea · 1 Parent(s): a22b103

WIP: Initial pipeline scripts

Files changed (8)
  1. .gitignore +1 -0
  2. README.md +5 -2
  3. datasets.py +61 -0
  4. download_egfx.sh +21 -0
  5. egfx.ipynb +0 -0
  6. guitar_generation_test.ipynb +0 -0
  7. models.py +105 -0
  8. train.py +32 -0
.gitignore CHANGED
@@ -4,3 +4,4 @@ wandb/
  *.egg-info/
  data/
  .DS_Store
+ __pycache__/
README.md CHANGED
@@ -1,4 +1,7 @@
- wget https://zenodo.org/record/7044411/files/Clean.zip?download=1 Clean.zip unzip Clean.zip
- python3 -m venv env pip install -e .
+ wget https://zenodo.org/record/7044411/files/Clean.zip?download=1 -O Clean.zip
+ unzip Clean.zip
+
+ python3 -m venv env
+ pip install -e .
datasets.py ADDED
@@ -0,0 +1,61 @@
+ import torch
+ from torch.utils.data import Dataset
+ import torchaudio
+ import torchaudio.transforms as T
+ import torch.nn.functional as F
+ from pathlib import Path
+ from typing import List
+
+ # EGFxSet: https://zenodo.org/record/7044411/
+
+ LENGTH = 2**18  # ~12 seconds at 22.05 kHz
+ ORIG_SR = 48000
+
+
+ class GuitarFXDataset(Dataset):
+     def __init__(
+         self,
+         root: str,
+         sample_rate: int,
+         length: int = LENGTH,
+         effect_type: List[str] = None,
+     ):
+         self.length = length
+         self.wet_files = []
+         self.dry_files = []
+         self.labels = []
+         self.root = Path(root)
+         if effect_type is None:
+             effect_type = [
+                 d.name for d in self.root.iterdir() if d.is_dir() and d.name != "Clean"
+             ]
+         for i, effect in enumerate(effect_type):
+             for pickup in Path(self.root / effect).iterdir():
+                 self.wet_files += list(pickup.glob("*.wav"))
+                 self.dry_files += list(self.root.glob(f"Clean/{pickup.name}/**/*.wav"))
+             self.labels += [i] * (len(self.wet_files) - len(self.labels))
+         print(
+             f"Found {len(self.wet_files)} wet files and {len(self.dry_files)} dry files"
+         )
+         self.resampler = T.Resample(ORIG_SR, sample_rate)
+
+     def __len__(self):
+         return len(self.dry_files)
+
+     def __getitem__(self, idx):
+         x, sr = torchaudio.load(self.wet_files[idx])
+         y, sr = torchaudio.load(self.dry_files[idx])
+         effect_label = self.labels[idx]
+
+         resampled_x = self.resampler(x)
+         resampled_y = self.resampler(y)
+         # Pad or crop to length
+         if resampled_x.shape[-1] < self.length:
+             resampled_x = F.pad(resampled_x, (0, self.length - resampled_x.shape[-1]))
+         elif resampled_x.shape[-1] > self.length:
+             resampled_x = resampled_x[:, : self.length]
+         if resampled_y.shape[-1] < self.length:
+             resampled_y = F.pad(resampled_y, (0, self.length - resampled_y.shape[-1]))
+         elif resampled_y.shape[-1] > self.length:
+             resampled_y = resampled_y[:, : self.length]
+         return (resampled_x, resampled_y, effect_label)
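
For reference, a minimal usage sketch of GuitarFXDataset (not part of this commit): the data/egfx root is an assumed extraction path matching download_egfx.sh below, and the printed shapes assume mono source files, consistent with in_channels=1 in models.py.

from torch.utils.data import DataLoader
from datasets import GuitarFXDataset

dataset = GuitarFXDataset(root="data/egfx", sample_rate=22050, effect_type=["Phaser"])
wet, dry, label = dataset[0]
print(wet.shape, dry.shape, label)  # torch.Size([1, 262144]) torch.Size([1, 262144]) 0

loader = DataLoader(dataset, batch_size=2, shuffle=True)
wet_batch, dry_batch, labels = next(iter(loader))  # (2, 1, 262144) wet/dry tensors plus labels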
download_egfx.sh ADDED
@@ -0,0 +1,21 @@
+ #!/bin/bash
+ mkdir -p data
+ cd data
+ mkdir -p egfx
+ cd egfx
+ wget https://zenodo.org/record/7044411/files/BluesDriver.zip?download=1 -O BluesDriver.zip
+ wget https://zenodo.org/record/7044411/files/Chorus.zip?download=1 -O Chorus.zip
+ wget https://zenodo.org/record/7044411/files/Clean.zip?download=1 -O Clean.zip
+ wget https://zenodo.org/record/7044411/files/Digital-Delay.zip?download=1 -O Digital-Delay.zip
+ wget https://zenodo.org/record/7044411/files/Flanger.zip?download=1 -O Flanger.zip
+ wget https://zenodo.org/record/7044411/files/Hall-Reverb.zip?download=1 -O Hall-Reverb.zip
+ wget https://zenodo.org/record/7044411/files/Phaser.zip?download=1 -O Phaser.zip
+ wget https://zenodo.org/record/7044411/files/Plate-Reverb.zip?download=1 -O Plate-Reverb.zip
+ wget https://zenodo.org/record/7044411/files/RAT.zip?download=1 -O RAT.zip
+ wget https://zenodo.org/record/7044411/files/Spring-Reverb.zip?download=1 -O Spring-Reverb.zip
+ wget https://zenodo.org/record/7044411/files/Sweep-Echo.zip?download=1 -O Sweep-Echo.zip
+ wget https://zenodo.org/record/7044411/files/TapeEcho.zip?download=1 -O TapeEcho.zip
+ wget https://zenodo.org/record/7044411/files/TubeScreamer.zip?download=1 -O TubeScreamer.zip
+ unzip \*.zip
+
+
egfx.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
guitar_generation_test.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
models.py ADDED
@@ -0,0 +1,105 @@
+ from audio_diffusion_pytorch import AudioDiffusionModel
+ import torch
+ from torch import Tensor
+ import pytorch_lightning as pl
+ from einops import rearrange
+ import wandb
+
+ SAMPLE_RATE = 22050  # From audio-diffusion-pytorch
+
+
+ class TCNWrapper(pl.LightningModule):
+     def __init__(self):
+         super().__init__()
+         self.model = AudioDiffusionModel(in_channels=1)
+
+     def forward(self, x: torch.Tensor):
+         return self.model(x)
+
+     def training_step(self, batch, batch_idx):
+         loss = self.common_step(batch, batch_idx, mode="train")
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         loss = self.common_step(batch, batch_idx, mode="val")
+
+     def common_step(self, batch, batch_idx, mode: str = "train"):
+         x, target, label = batch
+         loss = self(x)
+         self.log(f"{mode}_loss", loss, on_step=True, on_epoch=True)
+         return loss
+
+     def configure_optimizers(self):
+         return torch.optim.Adam(
+             self.parameters(), lr=1e-4, betas=(0.95, 0.999), eps=1e-6, weight_decay=1e-3
+         )
+
+
+ class AudioDiffusionWrapper(pl.LightningModule):
+     def __init__(self):
+         super().__init__()
+         self.model = AudioDiffusionModel(in_channels=1)
+
+     def forward(self, x: torch.Tensor):
+         return self.model(x)
+
+     def sample(self, *args, **kwargs) -> Tensor:
+         return self.model.sample(*args, **kwargs)
+
+     def training_step(self, batch, batch_idx):
+         loss = self.common_step(batch, batch_idx, mode="train")
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         loss = self.common_step(batch, batch_idx, mode="val")
+
+     def common_step(self, batch, batch_idx, mode: str = "train"):
+         x, target, label = batch
+         loss = self(x)
+         self.log(f"{mode}_loss", loss, on_step=True, on_epoch=True)
+         return loss
+
+     def configure_optimizers(self):
+         return torch.optim.Adam(
+             self.parameters(), lr=1e-4, betas=(0.95, 0.999), eps=1e-6, weight_decay=1e-3
+         )
+
+     def on_validation_epoch_start(self):
+         self.log_next = True
+
+     def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
+         x, target, label = batch
+         if self.log_next:
+             self.log_sample(x)
+             self.log_next = False
+
+     @torch.no_grad()
+     def log_sample(self, batch, num_steps=10):
+         # Get start diffusion noise
+         noise = torch.randn(batch.shape, device=self.device)
+         sampled = self.model.sample(
+             noise=noise, num_steps=num_steps  # Suggested range: 2-50
+         )
+         log_wandb_audio_batch(
+             id="sample",
+             samples=sampled,
+             sampling_rate=SAMPLE_RATE,
+             caption=f"Sampled in {num_steps} steps",
+         )
+
+
+ def log_wandb_audio_batch(
+     id: str, samples: Tensor, sampling_rate: int, caption: str = ""
+ ):
+     num_items = samples.shape[0]
+     samples = rearrange(samples, "b c t -> b t c")
+     for idx in range(num_items):
+         wandb.log(
+             {
+                 f"sample_{idx}_{id}": wandb.Audio(
+                     samples[idx].cpu().numpy(),
+                     caption=caption,
+                     sample_rate=sampling_rate,
+                 )
+             }
+         )
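
A hypothetical sampling sketch for AudioDiffusionWrapper (not part of this commit), reusing the same noise/num_steps arguments that log_sample passes to model.sample above; the checkpoint path is a placeholder.

import torch
from models import AudioDiffusionWrapper, SAMPLE_RATE

model = AudioDiffusionWrapper()  # or AudioDiffusionWrapper.load_from_checkpoint("<ckpt path>") for a trained model
model.eval()

with torch.no_grad():
    noise = torch.randn(1, 1, 2**18)  # (batch, channels, samples), matching LENGTH in datasets.py
    audio = model.sample(noise=noise, num_steps=50)  # more steps: slower, higher quality
print(audio.shape)  # torch.Size([1, 1, 262144]), intended for playback at SAMPLE_RATE (22050 Hz)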
train.py ADDED
@@ -0,0 +1,32 @@
+ from pytorch_lightning.loggers import WandbLogger
+ import pytorch_lightning as pl
+ import torch
+ from torch.utils.data import DataLoader
+ from datasets import GuitarFXDataset
+ from models import AudioDiffusionWrapper
+
+ SAMPLE_RATE = 22050
+ TRAIN_SPLIT = 0.8
+
+
+ def main():
+     # wandb_logger = WandbLogger(project="RemFX", save_dir="./")
+     trainer = pl.Trainer()  # logger=wandb_logger)
+     guitfx = GuitarFXDataset(
+         root="/Users/matthewrice/mir_datasets/egfxset",
+         sample_rate=SAMPLE_RATE,
+         effect_type=["Phaser"],
+     )
+     train_size = int(TRAIN_SPLIT * len(guitfx))
+     val_size = len(guitfx) - train_size
+     train_dataset, val_dataset = torch.utils.data.random_split(
+         guitfx, [train_size, val_size]
+     )
+     train = DataLoader(train_dataset, batch_size=2)
+     val = DataLoader(val_dataset, batch_size=2)
+     model = AudioDiffusionWrapper()
+     trainer.fit(model=model, train_dataloaders=train, val_dataloaders=val)
+
+
+ if __name__ == "__main__":
+     main()