Spaces:
Runtime error
Runtime error
mattricesound
commited on
Commit
•
0fbacb2
1
Parent(s):
f3350b1
Add flag to not log audio
Browse files- cfg/config.yaml +2 -0
- cfg/exp/default.yaml +2 -0
- remfx/callbacks.py +5 -2
cfg/config.yaml
CHANGED
@@ -11,6 +11,7 @@ logs_dir: "./logs"
|
|
11 |
render_files: True
|
12 |
render_root: "./data"
|
13 |
accelerator: null
|
|
|
14 |
|
15 |
# Effects
|
16 |
max_kept_effects: -1
|
@@ -47,6 +48,7 @@ callbacks:
|
|
47 |
audio_logging:
|
48 |
_target_: remfx.callbacks.AudioCallback
|
49 |
sample_rate: ${sample_rate}
|
|
|
50 |
metric_logging:
|
51 |
_target_: remfx.callbacks.MetricCallback
|
52 |
|
|
|
11 |
render_files: True
|
12 |
render_root: "./data"
|
13 |
accelerator: null
|
14 |
+
log_audio: True
|
15 |
|
16 |
# Effects
|
17 |
max_kept_effects: -1
|
|
|
48 |
audio_logging:
|
49 |
_target_: remfx.callbacks.AudioCallback
|
50 |
sample_rate: ${sample_rate}
|
51 |
+
log_audio: ${log_audio}
|
52 |
metric_logging:
|
53 |
_target_: remfx.callbacks.MetricCallback
|
54 |
|
cfg/exp/default.yaml
CHANGED
@@ -9,6 +9,8 @@ logs_dir: "./logs"
|
|
9 |
render_files: True
|
10 |
render_root: "./data"
|
11 |
accelerator: null
|
|
|
|
|
12 |
max_kept_effects: -1
|
13 |
max_removed_effects: -1
|
14 |
shuffle_kept_effects: True
|
|
|
9 |
render_files: True
|
10 |
render_root: "./data"
|
11 |
accelerator: null
|
12 |
+
log_audio: True
|
13 |
+
# Effects
|
14 |
max_kept_effects: -1
|
15 |
max_removed_effects: -1
|
16 |
shuffle_kept_effects: True
|
remfx/callbacks.py
CHANGED
@@ -7,10 +7,13 @@ from torch import Tensor
|
|
7 |
|
8 |
|
9 |
class AudioCallback(Callback):
|
10 |
-
def __init__(self, sample_rate, *args, **kwargs):
|
11 |
super().__init__(*args, **kwargs)
|
|
|
12 |
self.log_train_audio = True
|
13 |
self.sample_rate = sample_rate
|
|
|
|
|
14 |
|
15 |
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
|
16 |
# Log initial audio
|
@@ -41,7 +44,7 @@ class AudioCallback(Callback):
|
|
41 |
):
|
42 |
x, target, _, _ = batch
|
43 |
# Only run on first batch
|
44 |
-
if batch_idx == 0:
|
45 |
with torch.no_grad():
|
46 |
y = pl_module.model.sample(x)
|
47 |
# Concat samples together for easier viewing in dashboard
|
|
|
7 |
|
8 |
|
9 |
class AudioCallback(Callback):
|
10 |
+
def __init__(self, sample_rate, log_audio, *args, **kwargs):
|
11 |
super().__init__(*args, **kwargs)
|
12 |
+
self.log_audio = log_audio
|
13 |
self.log_train_audio = True
|
14 |
self.sample_rate = sample_rate
|
15 |
+
if not self.log_audio:
|
16 |
+
self.log_train_audio = False
|
17 |
|
18 |
def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
|
19 |
# Log initial audio
|
|
|
44 |
):
|
45 |
x, target, _, _ = batch
|
46 |
# Only run on first batch
|
47 |
+
if batch_idx == 0 and self.log_audio:
|
48 |
with torch.no_grad():
|
49 |
y = pl_module.model.sample(x)
|
50 |
# Concat samples together for easier viewing in dashboard
|