mattricesound committed on
Commit
8949a8c
1 Parent(s): 14ae0ea

Initial ptl model and training script for umx

Browse files
Files changed (5) hide show
  1. .gitignore +3 -1
  2. .gitmodules +3 -0
  3. models.py +97 -16
  4. train.py +9 -5
  5. umx +1 -0
.gitignore CHANGED
@@ -4,4 +4,6 @@ wandb/
4
  *.egg-info/
5
  data/
6
  .DS_Store
7
- __pycache__/
 
 
 
4
  *.egg-info/
5
  data/
6
  .DS_Store
7
+ __pycache__/
8
+ lightning_logs/
9
+ RemFX/
.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [submodule "umx"]
2
+ path = umx
3
+ url = https://github.com/sigsep/open-unmix-pytorch
models.py CHANGED
@@ -1,44 +1,103 @@
1
- from audio_diffusion_pytorch import AudioDiffusionModel
2
  import torch
3
  from torch import Tensor
4
  import pytorch_lightning as pl
5
  from einops import rearrange
6
  import wandb
 
 
 
 
 
 
 
7
 
8
  SAMPLE_RATE = 22050 # From audio-diffusion-pytorch
9
 
10
 
11
- class TCNWrapper(pl.LightningModule):
12
- def __init__(self):
 
 
 
 
 
13
  super().__init__()
14
- self.model = AudioDiffusionModel(in_channels=1)
 
 
 
 
 
 
 
 
15
 
16
  def forward(self, x: torch.Tensor):
17
  return self.model(x)
18
 
19
  def training_step(self, batch, batch_idx):
20
- loss = self.common_step(batch, batch_idx, mode="train")
21
  return loss
22
 
23
  def validation_step(self, batch, batch_idx):
24
- loss = self.common_step(batch, batch_idx, mode="val")
 
25
 
26
  def common_step(self, batch, batch_idx, mode: str = "train"):
27
  x, target, label = batch
28
- loss = self(x)
 
 
 
 
 
29
  self.log(f"{mode}_loss", loss, on_step=True, on_epoch=True)
30
- return loss
31
 
32
  def configure_optimizers(self):
33
  return torch.optim.Adam(
34
  self.parameters(), lr=1e-4, betas=(0.95, 0.999), eps=1e-6, weight_decay=1e-3
35
  )
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
- class AudioDiffusionWrapper(pl.LightningModule):
39
- def __init__(self):
40
  super().__init__()
41
- self.model = AudioDiffusionModel(in_channels=1)
42
 
43
  def forward(self, x: torch.Tensor):
44
  return self.model(x)
@@ -77,10 +136,8 @@ class AudioDiffusionWrapper(pl.LightningModule):
77
  def log_sample(self, batch, num_steps=10):
78
  # Get start diffusion noise
79
  noise = torch.randn(batch.shape, device=self.device)
80
- sampled = self.model.sample(
81
- noise=noise, num_steps=num_steps # Suggested range: 2-50
82
- )
83
- self.log_wandb_audio_batch(
84
  id="sample",
85
  samples=sampled,
86
  sampling_rate=SAMPLE_RATE,
@@ -96,10 +153,34 @@ def log_wandb_audio_batch(
96
  for idx in range(num_items):
97
  wandb.log(
98
  {
99
- f"sample_{idx}_{id}": wandb.Audio(
100
  samples[idx].cpu().numpy(),
101
  caption=caption,
102
  sample_rate=sampling_rate,
103
  )
104
  }
105
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  from torch import Tensor
3
  import pytorch_lightning as pl
4
  from einops import rearrange
5
  import wandb
6
+ from audio_diffusion_pytorch import AudioDiffusionModel
7
+
8
+ import sys
9
+
10
+ sys.path.append("/Users/matthewrice/Developer/remfx/umx/")
11
+ from umx.openunmix.model import OpenUnmix, Separator
12
+
13
 
14
  SAMPLE_RATE = 22050 # From audio-diffusion-pytorch
15
 
16
 
17
class OpenUnmixModel(pl.LightningModule):
    """Lightning wrapper around Open-Unmix for audio-effect removal.

    Trains the network to map the alpha-compressed magnitude spectrogram
    of the processed (wet) input onto that of the clean (dry) target,
    minimising an MSE loss in the spectrogram domain.
    """

    def __init__(
        self,
        n_fft: int = 2048,
        hop_length: int = 512,
        alpha: float = 0.3,
    ):
        super().__init__()
        self.model = OpenUnmix(
            nb_channels=1,
            nb_bins=n_fft // 2 + 1,
        )
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.alpha = alpha
        # Registered as a buffer so the window follows the module across
        # devices and is saved with checkpoints without being a parameter.
        self.register_buffer("window", torch.hann_window(n_fft))

    def forward(self, x: torch.Tensor):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        loss, _ = self.common_step(batch, batch_idx, mode="train")
        return loss

    def validation_step(self, batch, batch_idx):
        loss, Y = self.common_step(batch, batch_idx, mode="val")
        return loss, Y

    def common_step(self, batch, batch_idx, mode: str = "train"):
        """Compute the spectrogram-domain MSE for one batch and log it."""
        x, target, label = batch
        wet_spec = spectrogram(x, self.window, self.n_fft, self.hop_length, self.alpha)
        pred_spec = self(wet_spec)
        dry_spec = spectrogram(
            target, self.window, self.n_fft, self.hop_length, self.alpha
        )
        loss = torch.nn.functional.mse_loss(pred_spec, dry_spec)
        self.log(f"{mode}_loss", loss, on_step=True, on_epoch=True)
        return loss, pred_spec

    def configure_optimizers(self):
        return torch.optim.Adam(
            self.parameters(), lr=1e-4, betas=(0.95, 0.999), eps=1e-6, weight_decay=1e-3
        )

    def on_validation_epoch_start(self):
        # Arm audio logging so only the first validation batch is logged.
        self.log_next = True

    def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
        if not self.log_next:
            return
        x, target, label = batch
        # Separator wraps the trained model to synthesise waveform output
        # from the input mixture for qualitative listening checks.
        separator = Separator(
            target_models={"other": self.model},
            nb_channels=1,
            sample_rate=SAMPLE_RATE,
            n_fft=self.n_fft,
            n_hop=self.hop_length,
        )
        estimate = separator(x).squeeze(1)
        caption = f"Epoch {self.current_epoch}"
        for audio_id, clip in (
            ("sample", x),
            ("prediction", estimate),
            ("target", target),
        ):
            log_wandb_audio_batch(
                id=audio_id,
                samples=clip,
                sampling_rate=SAMPLE_RATE,
                caption=caption,
            )
        self.log_next = False
95
+
96
 
97
+ class DiffusionGenerationModel(pl.LightningModule):
98
+ def __init__(self, model: torch.nn.Module):
99
  super().__init__()
100
+ self.model = model
101
 
102
  def forward(self, x: torch.Tensor):
103
  return self.model(x)
 
136
  def log_sample(self, batch, num_steps=10):
137
  # Get start diffusion noise
138
  noise = torch.randn(batch.shape, device=self.device)
139
+ sampled = self.sample(noise=noise, num_steps=num_steps) # Suggested range: 2-50
140
+ log_wandb_audio_batch(
 
 
141
  id="sample",
142
  samples=sampled,
143
  sampling_rate=SAMPLE_RATE,
 
153
  for idx in range(num_items):
154
  wandb.log(
155
  {
156
+ f"{id}_{idx}": wandb.Audio(
157
  samples[idx].cpu().numpy(),
158
  caption=caption,
159
  sample_rate=sampling_rate,
160
  )
161
  }
162
  )
163
+
164
+
165
def spectrogram(
    x: torch.Tensor,
    window: torch.Tensor,
    n_fft: int,
    hop_length: int,
    alpha: float,
) -> torch.Tensor:
    """Compute an alpha-compressed magnitude spectrogram.

    Args:
        x: Waveform of shape (batch, channels, samples).
        window: Analysis window of length ``n_fft`` (e.g. a Hann window).
        n_fft: FFT size.
        hop_length: Hop between adjacent STFT frames, in samples.
        alpha: Magnitude-compression exponent (1.0 = linear magnitude).

    Returns:
        Tensor of shape (batch, channels, n_fft // 2 + 1, frames) holding
        ``(|STFT(x)| + 1e-8) ** alpha``.
    """
    bs, chs, _ = x.size()
    # Fold channels onto the batch dim: torch.stft expects a 2D input.
    # Use reshape (not view) so non-contiguous inputs (e.g. slices or
    # transposed tensors) are accepted as well.
    x = x.reshape(bs * chs, -1)

    X = torch.stft(
        x,
        n_fft=n_fft,
        hop_length=hop_length,
        window=window,
        return_complex=True,
    )

    # Restore the channel dimension.
    X = X.view(bs, chs, X.shape[-2], X.shape[-1])

    # Small epsilon keeps pow() finite/differentiable at zero magnitude.
    return torch.pow(X.abs() + 1e-8, alpha)
train.py CHANGED
@@ -3,17 +3,18 @@ import pytorch_lightning as pl
3
  import torch
4
  from torch.utils.data import DataLoader
5
  from datasets import GuitarFXDataset
6
- from models import AudioDiffusionWrapper
 
7
 
8
  SAMPLE_RATE = 22050
9
  TRAIN_SPLIT = 0.8
10
 
11
 
12
  def main():
13
- # wandb_logger = WandbLogger(project="RemFX", save_dir="./")
14
- trainer = pl.Trainer() # logger=wandb_logger)
15
  guitfx = GuitarFXDataset(
16
- root="/Users/matthewrice/mir_datasets/egfxset",
17
  sample_rate=SAMPLE_RATE,
18
  effect_type=["Phaser"],
19
  )
@@ -24,7 +25,10 @@ def main():
24
  )
25
  train = DataLoader(train_dataset, batch_size=2)
26
  val = DataLoader(val_dataset, batch_size=2)
27
- model = AudioDiffusionWrapper()
 
 
 
28
  trainer.fit(model=model, train_dataloaders=train, val_dataloaders=val)
29
 
30
 
 
3
  import torch
4
  from torch.utils.data import DataLoader
5
  from datasets import GuitarFXDataset
6
+ from models import DiffusionGenerationModel, OpenUnmixModel
7
+
8
 
9
  SAMPLE_RATE = 22050
10
  TRAIN_SPLIT = 0.8
11
 
12
 
13
  def main():
14
+ wandb_logger = WandbLogger(project="RemFX", save_dir="./")
15
+ trainer = pl.Trainer(logger=wandb_logger, max_epochs=10)
16
  guitfx = GuitarFXDataset(
17
+ root="/Users/matthewrice/Developer/remfx/data/egfx",
18
  sample_rate=SAMPLE_RATE,
19
  effect_type=["Phaser"],
20
  )
 
25
  )
26
  train = DataLoader(train_dataset, batch_size=2)
27
  val = DataLoader(val_dataset, batch_size=2)
28
+
29
+ # model = DiffusionGenerationModel()
30
+ model = OpenUnmixModel()
31
+
32
  trainer.fit(model=model, train_dataloaders=train, val_dataloaders=val)
33
 
34
 
umx ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit 05fd4d8a0e3e50e308579052d762a342647c3408