mattricesound committed
Commit: ccecb22
Parent(s): 647e1a1

Update file structure, remove os path dependency for umx, and increase default sample rate to 48 kHz

Files changed (7):
  1. README.md +1 -0
  2. config.yaml +1 -1
  3. exp/audio_diffusion.yaml +3 -2
  4. exp/umx.yaml +3 -2
  5. models.py +0 -196
  6. train.py +2 -2
  7. utils.py +0 -71
README.md CHANGED
```diff
@@ -2,6 +2,7 @@
 ## Install Packages
 `python3 -m venv env`
 `pip install -e .`
+`pip install -e umx`
 
 ## Download [GuitarFX Dataset](https://zenodo.org/record/7044411/)
 `./download_egfx.sh`
```
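The new `pip install -e umx` step installs the vendored OpenUnmix code as an editable package, replacing the `sys.path` hack that the deleted `models.py` (below) performed at import time. A minimal sketch of the difference, assuming the editable install exposes the same `umx.openunmix` package path:

```python
# Before this commit, models.py patched the import path at runtime:
import sys
sys.path.append("./umx")  # fragile: only works when run from the repo root
from umx.openunmix.model import OpenUnmix, Separator

# After `pip install -e umx`, the package resolves through site-packages,
# so the same import works from any working directory (assumed layout):
from umx.openunmix.model import OpenUnmix, Separator
```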
config.yaml CHANGED
```diff
@@ -4,7 +4,7 @@ defaults:
 seed: 12345
 train: True
 length: 262144
-sample_rate: 22050
+sample_rate: 48000
 logs_dir: "./logs"
 log_every_n_steps: 1000
 
```
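For reference, a minimal sketch of how these top-level keys surface at runtime, assuming the standard Hydra entry point used by `train.py`:

```python
import hydra
from omegaconf import DictConfig


@hydra.main(config_path=".", config_name="config")
def main(cfg: DictConfig) -> None:
    # Top-level keys from config.yaml are attributes on the composed config.
    print(cfg.sample_rate)  # 48000 after this commit
    print(cfg.length)  # 262144


if __name__ == "__main__":
    main()
```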
exp/audio_diffusion.yaml CHANGED
```diff
@@ -1,13 +1,14 @@
 # @package _global_
 model:
-  _target_: models.RemFXModel
+  _target_: remfx.models.RemFXModel
   lr: 1e-4
   lr_beta1: 0.95
   lr_beta2: 0.999
   lr_eps: 1e-6
   lr_weight_decay: 1e-3
+  sample_rate: ${sample_rate}
   network:
-    _target_: models.DiffusionGenerationModel
+    _target_: remfx.models.DiffusionGenerationModel
     n_channels: 1
 datamodule:
   dataset:
```
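The added `sample_rate: ${sample_rate}` line is OmegaConf interpolation: the model now reads the global `sample_rate` from `config.yaml` instead of hard-coding its own (the deleted `models.py` below pinned `SAMPLE_RATE = 22050`). A self-contained sketch of the mechanism:

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "sample_rate": 48000,  # global value, as in config.yaml
        "model": {"sample_rate": "${sample_rate}"},  # exp-config interpolation
    }
)
assert cfg.model.sample_rate == 48000  # resolved on access
```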
exp/umx.yaml CHANGED
```diff
@@ -1,13 +1,14 @@
 # @package _global_
 model:
-  _target_: models.RemFXModel
+  _target_: remfx.models.RemFXModel
   lr: 1e-4
   lr_beta1: 0.95
   lr_beta2: 0.999
   lr_eps: 1e-6
   lr_weight_decay: 1e-3
+  sample_rate: ${sample_rate}
   network:
-    _target_: models.OpenUnmixModel
+    _target_: remfx.models.OpenUnmixModel
     n_fft: 2048
     hop_length: 512
     n_channels: 1
```
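The `_target_` keys are consumed by Hydra's instantiation utility, so moving them from `models.*` to `remfx.models.*` requires the classes to be importable from the installed `remfx` package. A sketch of the resolution step, assuming `cfg` holds the composed config above:

```python
import hydra

# Builds remfx.models.OpenUnmixModel(n_fft=2048, hop_length=512, n_channels=1).
# With recursive instantiation (Hydra's default), instantiating cfg.model
# would also build the nested network from its own _target_.
network = hydra.utils.instantiate(cfg.model.network)
```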
models.py DELETED
```diff
@@ -1,196 +0,0 @@
-import torch
-from torch import Tensor, nn
-import pytorch_lightning as pl
-from einops import rearrange
-import wandb
-from audio_diffusion_pytorch import AudioDiffusionModel
-import auraloss
-
-import sys
-
-sys.path.append("./umx")
-from umx.openunmix.model import OpenUnmix, Separator
-
-
-SAMPLE_RATE = 22050  # From audio-diffusion-pytorch
-
-
-class RemFXModel(pl.LightningModule):
-    def __init__(
-        self,
-        lr: float,
-        lr_beta1: float,
-        lr_beta2: float,
-        lr_eps: float,
-        lr_weight_decay: float,
-        network: nn.Module,
-    ):
-        super().__init__()
-        self.lr = lr
-        self.lr_beta1 = lr_beta1
-        self.lr_beta2 = lr_beta2
-        self.lr_eps = lr_eps
-        self.lr_weight_decay = lr_weight_decay
-        self.model = network
-
-    @property
-    def device(self):
-        return next(self.model.parameters()).device
-
-    def configure_optimizers(self):
-        optimizer = torch.optim.AdamW(
-            list(self.model.parameters()),
-            lr=self.lr,
-            betas=(self.lr_beta1, self.lr_beta2),
-            eps=self.lr_eps,
-            weight_decay=self.lr_weight_decay,
-        )
-        return optimizer
-
-    def training_step(self, batch, batch_idx):
-        loss = self.common_step(batch, batch_idx, mode="train")
-        return loss
-
-    def validation_step(self, batch, batch_idx):
-        loss = self.common_step(batch, batch_idx, mode="valid")
-
-    def common_step(self, batch, batch_idx, mode: str = "train"):
-        loss = self.model(batch)
-        self.log(f"{mode}_loss", loss)
-        return loss
-
-    def on_validation_epoch_start(self):
-        self.log_next = True
-
-    def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
-        if self.log_next:
-            x, target, label = batch
-            y = self.model.sample(x)
-            log_wandb_audio_batch(
-                logger=self.logger,
-                id="sample",
-                samples=x.cpu(),
-                sampling_rate=SAMPLE_RATE,
-                caption=f"Epoch {self.current_epoch}",
-            )
-            log_wandb_audio_batch(
-                logger=self.logger,
-                id="prediction",
-                samples=y.cpu(),
-                sampling_rate=SAMPLE_RATE,
-                caption=f"Epoch {self.current_epoch}",
-            )
-            log_wandb_audio_batch(
-                logger=self.logger,
-                id="target",
-                samples=target.cpu(),
-                sampling_rate=SAMPLE_RATE,
-                caption=f"Epoch {self.current_epoch}",
-            )
-            self.log_next = False
-
-
-class OpenUnmixModel(torch.nn.Module):
-    def __init__(
-        self,
-        n_fft: int = 2048,
-        hop_length: int = 512,
-        n_channels: int = 1,
-        alpha: float = 0.3,
-        sample_rate: int = 22050,
-    ):
-        super().__init__()
-        self.n_channels = n_channels
-        self.n_fft = n_fft
-        self.hop_length = hop_length
-        self.alpha = alpha
-        window = torch.hann_window(n_fft)
-        self.register_buffer("window", window)
-
-        self.num_bins = self.n_fft // 2 + 1
-        self.sample_rate = sample_rate
-        self.model = OpenUnmix(
-            nb_channels=self.n_channels,
-            nb_bins=self.num_bins,
-        )
-        self.separator = Separator(
-            target_models={"other": self.model},
-            nb_channels=self.n_channels,
-            sample_rate=self.sample_rate,
-            n_fft=self.n_fft,
-            n_hop=self.hop_length,
-        )
-        self.loss_fn = auraloss.freq.MultiResolutionSTFTLoss(
-            n_bins=self.num_bins, sample_rate=self.sample_rate
-        )
-
-    def forward(self, batch):
-        x, target, label = batch
-        X = spectrogram(x, self.window, self.n_fft, self.hop_length, self.alpha)
-        Y = self.model(X)
-        sep_out = self.separator(x).squeeze(1)
-        loss = self.loss_fn(sep_out, target)
-
-        return loss
-
-    def sample(self, x: Tensor) -> Tensor:
-        return self.separator(x).squeeze(1)
-
-
-class DiffusionGenerationModel(nn.Module):
-    def __init__(self, n_channels: int = 1):
-        super().__init__()
-        self.model = AudioDiffusionModel(in_channels=n_channels)
-
-    def forward(self, batch):
-        x, target, label = batch
-        return self.model(x)
-
-    def sample(self, x: Tensor, num_steps: int = 10) -> Tensor:
-        noise = torch.randn(x.shape).to(x)
-        return self.model.sample(noise, num_steps=num_steps)
-
-
-def log_wandb_audio_batch(
-    logger: pl.loggers.WandbLogger,
-    id: str,
-    samples: Tensor,
-    sampling_rate: int,
-    caption: str = "",
-):
-    num_items = samples.shape[0]
-    samples = rearrange(samples, "b c t -> b t c")
-    for idx in range(num_items):
-        logger.experiment.log(
-            {
-                f"{id}_{idx}": wandb.Audio(
-                    samples[idx].cpu().numpy(),
-                    caption=caption,
-                    sample_rate=sampling_rate,
-                )
-            }
-        )
-
-
-def spectrogram(
-    x: torch.Tensor,
-    window: torch.Tensor,
-    n_fft: int,
-    hop_length: int,
-    alpha: float,
-) -> torch.Tensor:
-    bs, chs, samp = x.size()
-    x = x.view(bs * chs, -1)  # move channels onto batch dim
-
-    X = torch.stft(
-        x,
-        n_fft=n_fft,
-        hop_length=hop_length,
-        window=window,
-        return_complex=True,
-    )
-
-    # move channels back
-    X = X.view(bs, chs, X.shape[-2], X.shape[-1])
-
-    return torch.pow(X.abs() + 1e-8, alpha)
```
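Nothing here is gone for good: per the commit message the file structure was reorganized, and the updated `train.py` and exp configs now reference these classes through the `remfx` package. The expected import after this commit (the relocated module itself is not shown in this diff):

```python
# Assumed new location, inferred from the updated _target_ paths and train.py:
from remfx.models import RemFXModel, OpenUnmixModel, DiffusionGenerationModel
```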
train.py CHANGED
```diff
@@ -2,10 +2,10 @@ from pytorch_lightning.loggers import WandbLogger
 import pytorch_lightning as pl
 from torch.utils.data import DataLoader
 from datasets import GuitarFXDataset
-from models import DiffusionGenerationModel, OpenUnmixModel
+from remfx.models import DiffusionGenerationModel, OpenUnmixModel
 import hydra
 from omegaconf import DictConfig
-import utils
+import remfx.utils as utils
 
 log = utils.get_logger(__name__)
 
```
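Aliasing the package module back to `utils` keeps every call site (`utils.get_logger(...)`) untouched; only the import line changes. Equivalent forms, for reference:

```python
import remfx.utils as utils  # form used in this commit
from remfx import utils  # equivalent alternative

log = utils.get_logger(__name__)  # call sites are unchanged either way
```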
utils.py DELETED
```diff
@@ -1,71 +0,0 @@
-import logging
-from typing import List
-import pytorch_lightning as pl
-from omegaconf import DictConfig
-from pytorch_lightning.utilities import rank_zero_only
-
-
-def get_logger(name=__name__) -> logging.Logger:
-    """Initializes multi-GPU-friendly python command line logger."""
-
-    logger = logging.getLogger(name)
-
-    # this ensures all logging levels get marked with the rank zero decorator
-    # otherwise logs would get multiplied for each GPU process in multi-GPU setup
-    for level in (
-        "debug",
-        "info",
-        "warning",
-        "error",
-        "exception",
-        "fatal",
-        "critical",
-    ):
-        setattr(logger, level, rank_zero_only(getattr(logger, level)))
-
-    return logger
-
-
-log = get_logger(__name__)
-
-
-@rank_zero_only
-def log_hyperparameters(
-    config: DictConfig,
-    model: pl.LightningModule,
-    datamodule: pl.LightningDataModule,
-    trainer: pl.Trainer,
-    callbacks: List[pl.Callback],
-    logger: pl.loggers.LightningLoggerBase,
-) -> None:
-    """Controls which config parts are saved by Lightning loggers.
-    Additionaly saves:
-    - number of model parameters
-    """
-
-    if not trainer.logger:
-        return
-
-    hparams = {}
-
-    # choose which parts of hydra config will be saved to loggers
-    hparams["model"] = config["model"]
-
-    # save number of model parameters
-    hparams["model/params/total"] = sum(p.numel() for p in model.parameters())
-    hparams["model/params/trainable"] = sum(
-        p.numel() for p in model.parameters() if p.requires_grad
-    )
-    hparams["model/params/non_trainable"] = sum(
-        p.numel() for p in model.parameters() if not p.requires_grad
-    )
-
-    hparams["datamodule"] = config["datamodule"]
-    hparams["trainer"] = config["trainer"]
-
-    if "seed" in config:
-        hparams["seed"] = config["seed"]
-    if "callbacks" in config:
-        hparams["callbacks"] = config["callbacks"]
-
-    logger.experiment.config.update(hparams)
```
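The deleted helper wraps every level of the standard logger in `rank_zero_only`, so multi-GPU (DDP) runs emit each message once instead of once per process. A minimal illustration of the decorator it relies on:

```python
from pytorch_lightning.utilities import rank_zero_only


@rank_zero_only
def print_once(msg: str) -> None:
    # Under DDP this body executes only on the process with global rank 0;
    # on every other rank the call is a no-op.
    print(msg)


print_once("dataset loaded")
```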