mattricesound commited on
Commit
0fbacb2
1 Parent(s): f3350b1

Add flag to not log audio

Browse files
Files changed (3) hide show
  1. cfg/config.yaml +2 -0
  2. cfg/exp/default.yaml +2 -0
  3. remfx/callbacks.py +5 -2
cfg/config.yaml CHANGED
@@ -11,6 +11,7 @@ logs_dir: "./logs"
11
  render_files: True
12
  render_root: "./data"
13
  accelerator: null
 
14
 
15
  # Effects
16
  max_kept_effects: -1
@@ -47,6 +48,7 @@ callbacks:
47
  audio_logging:
48
  _target_: remfx.callbacks.AudioCallback
49
  sample_rate: ${sample_rate}
 
50
  metric_logging:
51
  _target_: remfx.callbacks.MetricCallback
52
 
 
11
  render_files: True
12
  render_root: "./data"
13
  accelerator: null
14
+ log_audio: True
15
 
16
  # Effects
17
  max_kept_effects: -1
 
48
  audio_logging:
49
  _target_: remfx.callbacks.AudioCallback
50
  sample_rate: ${sample_rate}
51
+ log_audio: ${log_audio}
52
  metric_logging:
53
  _target_: remfx.callbacks.MetricCallback
54
 
cfg/exp/default.yaml CHANGED
@@ -9,6 +9,8 @@ logs_dir: "./logs"
9
  render_files: True
10
  render_root: "./data"
11
  accelerator: null
 
 
12
  max_kept_effects: -1
13
  max_removed_effects: -1
14
  shuffle_kept_effects: True
 
9
  render_files: True
10
  render_root: "./data"
11
  accelerator: null
12
+ log_audio: True
13
+ # Effects
14
  max_kept_effects: -1
15
  max_removed_effects: -1
16
  shuffle_kept_effects: True
remfx/callbacks.py CHANGED
@@ -7,10 +7,13 @@ from torch import Tensor
7
 
8
 
9
  class AudioCallback(Callback):
10
- def __init__(self, sample_rate, *args, **kwargs):
11
  super().__init__(*args, **kwargs)
 
12
  self.log_train_audio = True
13
  self.sample_rate = sample_rate
 
 
14
 
15
  def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
16
  # Log initial audio
@@ -41,7 +44,7 @@ class AudioCallback(Callback):
41
  ):
42
  x, target, _, _ = batch
43
  # Only run on first batch
44
- if batch_idx == 0:
45
  with torch.no_grad():
46
  y = pl_module.model.sample(x)
47
  # Concat samples together for easier viewing in dashboard
 
7
 
8
 
9
  class AudioCallback(Callback):
10
+ def __init__(self, sample_rate, log_audio, *args, **kwargs):
11
  super().__init__(*args, **kwargs)
12
+ self.log_audio = log_audio
13
  self.log_train_audio = True
14
  self.sample_rate = sample_rate
15
+ if not self.log_audio:
16
+ self.log_train_audio = False
17
 
18
  def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
19
  # Log initial audio
 
44
  ):
45
  x, target, _, _ = batch
46
  # Only run on first batch
47
+ if batch_idx == 0 and self.log_audio:
48
  with torch.no_grad():
49
  y = pl_module.model.sample(x)
50
  # Concat samples together for easier viewing in dashboard