mattricesound committed
Commit: d8d3e30
Parent: 4a7a6b8

Add gradient clipping and lr scheduler

Files changed (5):
  1. cfg/config.yaml +5 -1
  2. remfx/datasets.py +4 -7
  3. remfx/models.py +22 -0
  4. remfx/utils.py +3 -4
  5. scripts/train.py +1 -0
cfg/config.yaml CHANGED
@@ -8,7 +8,6 @@ train: True
 sample_rate: 48000
 chunk_size: 262144 # 5.5s
 logs_dir: "./logs"
-log_every_n_steps: 1000
 render_files: True
 render_root: "./data/processed"

@@ -22,6 +21,9 @@ callbacks:
 verbose: False
 dirpath: ${logs_dir}/ckpts/${now:%Y-%m-%d-%H-%M-%S}
 filename: '{epoch:02d}-{valid_loss:.3f}'
+learning_rate_monitor:
+  _target_: pytorch_lightning.callbacks.LearningRateMonitor
+  logging_interval: "step"

 datamodule:
   _target_: remfx.datasets.VocalSetDatamodule
@@ -77,3 +79,5 @@ trainer:
 accumulate_grad_batches: 1
 accelerator: null
 devices: 1
+gradient_clip_val: 10.0
+max_steps: 50000
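
Once Hydra instantiates this config, the new keys map onto standard PyTorch Lightning pieces. A minimal sketch of the equivalent manual wiring (the exact instantiation path used by this repo is assumed, not shown in the diff):

import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor

# Log the current learning rate every optimizer step, so the scheduled
# drops added in remfx/models.py are visible in the logger.
lr_monitor = LearningRateMonitor(logging_interval="step")

trainer = pl.Trainer(
    callbacks=[lr_monitor],
    gradient_clip_val=10.0,  # clip the gradient norm at 10.0, as set under trainer:
    max_steps=50000,         # stop after 50k optimizer steps
    accumulate_grad_batches=1,
    devices=1,
)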
remfx/datasets.py CHANGED
@@ -17,7 +17,7 @@ class VocalSet(Dataset):
         self,
         root: str,
         sample_rate: int,
-        chunk_size_in_sec: int = 3,
+        chunk_size: int = 3,
         effect_types: List[torch.nn.Module] = None,
         render_files: bool = True,
         render_root: str = None,
@@ -28,7 +28,7 @@ class VocalSet(Dataset):
         self.song_idx = []
         self.root = Path(root)
         self.render_root = Path(render_root)
-        self.chunk_size_in_sec = chunk_size_in_sec
+        self.chunk_size = chunk_size
         self.sample_rate = sample_rate
         self.mode = mode
@@ -46,15 +46,12 @@ class VocalSet(Dataset):
         # Split audio file into chunks, resample, then apply random effects
         self.processed_root.mkdir(parents=True, exist_ok=True)
         for audio_file in tqdm(self.files, total=len(self.files)):
-            chunks, orig_sr = create_sequential_chunks(
-                audio_file, self.chunk_size_in_sec
-            )
+            chunks, orig_sr = create_sequential_chunks(audio_file, self.chunk_size)
             for chunk in chunks:
                 resampled_chunk = torchaudio.functional.resample(
                     chunk, orig_sr, sample_rate
                 )
-                chunk_size_in_samples = self.chunk_size_in_sec * self.sample_rate
-                if resampled_chunk.shape[-1] < chunk_size_in_samples:
+                if resampled_chunk.shape[-1] < chunk_size:
                     # Skip if chunk is too small
                     continue
                 # Apply effect
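
With this rename, chunk_size is interpreted as a length in samples (matching chunk_size: 262144, roughly 5.5 s at 48 kHz, in cfg/config.yaml) rather than seconds. A usage sketch under that assumption; the paths below are placeholders, not values taken from the repo:

from remfx.datasets import VocalSet

# Illustrative construction only; chunk_size is now a sample count.
dataset = VocalSet(
    root="./data/VocalSet",
    sample_rate=48000,
    chunk_size=262144,
    effect_types=None,
    render_files=True,
    render_root="./data/processed",
)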
remfx/models.py CHANGED
@@ -55,6 +55,28 @@ class RemFXModel(pl.LightningModule):
         )
         return optimizer

+    # Add step-based learning rate scheduler
+    def optimizer_step(
+        self,
+        epoch,
+        batch_idx,
+        optimizer,
+        optimizer_idx,
+        optimizer_closure,
+        on_tpu=False,
+        using_lbfgs=False,
+    ):
+        # update params
+        optimizer.step(closure=optimizer_closure)
+
+        # update learning rate. Reduce by factor of 10 at 80% and 95% of training
+        if self.trainer.global_step == 0.8 * self.trainer.max_steps:
+            for pg in optimizer.param_groups:
+                pg["lr"] = 0.1 * pg["lr"]
+        if self.trainer.global_step == 0.95 * self.trainer.max_steps:
+            for pg in optimizer.param_groups:
+                pg["lr"] = 0.1 * pg["lr"]
+
     def training_step(self, batch, batch_idx):
         loss = self.common_step(batch, batch_idx, mode="train")
         return loss
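
With max_steps: 50000 from the config, the learning rate is cut to 10% at step 40000 and again at step 47500; the equality checks only fire because 0.8 and 0.95 of max_steps land on whole step counts. An alternative sketch (not what this commit does) expressing the same schedule declaratively via a MultiStepLR returned from configure_optimizers; the base lr is a placeholder:

import torch

def configure_optimizers(self):
    # Multiply the LR by 0.1 at the 80% and 95% milestones of a 50k-step run.
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[40000, 47500], gamma=0.1
    )
    return {
        "optimizer": optimizer,
        "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
    }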
remfx/utils.py CHANGED
@@ -132,10 +132,9 @@ def create_sequential_chunks(
     """
     chunks = []
     audio, sr = torchaudio.load(audio_file)
-    chunk_size_in_samples = chunk_size * sr
-    chunk_starts = torch.arange(0, audio.shape[-1], chunk_size_in_samples)
+    chunk_starts = torch.arange(0, audio.shape[-1], chunk_size)
     for start in chunk_starts:
-        if start + chunk_size_in_samples > audio.shape[-1]:
+        if start + chunk_size > audio.shape[-1]:
             break
-        chunks.append(audio[:, start : start + chunk_size_in_samples])
+        chunks.append(audio[:, start : start + chunk_size])
     return chunks, sr
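
After this change, create_sequential_chunks slices the file into non-overlapping windows of chunk_size samples at the file's native sample rate and drops the final partial window. A usage sketch (the file path is a placeholder, and the positional argument order is assumed from the diff):

from remfx.utils import create_sequential_chunks

# Split a file into 262144-sample windows; the trailing remainder is discarded.
chunks, sr = create_sequential_chunks("example.wav", 262144)
for chunk in chunks:
    print(chunk.shape)  # (channels, 262144) for every chunk that was kept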
scripts/train.py CHANGED
@@ -42,6 +42,7 @@ def main(cfg: DictConfig):
     summary = ModelSummary(model)
     print(summary)
     trainer.fit(model=model, datamodule=datamodule)
+    trainer.test(model=model, datamodule=datamodule)


 if __name__ == "__main__":