unpairedelectron07 committed on
Commit
0a79d52
1 Parent(s): f771d8e

Upload 2 files

Browse files
audiocraft/grids/diffusion/4_bands_base_32khz.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Training of the 4 diffusion models described in
9
+ "From Discrete Tokens to High-Fidelity Audio Using Multi-Band Diffusion"
10
+ (paper link).
11
+ """
12
+
13
+ from ._explorers import DiffusionExplorer
14
+
15
+
16
@DiffusionExplorer
def explorer(launcher):
    """Schedule the four per-band diffusion trainings as a single job array.

    Every job shares the same solver and dataset bindings; the jobs differ
    only in the band index and in the processor settings.
    """
    launcher.slurm_(gpus=4, partition='learnfair')
    launcher.bind_({'solver': 'diffusion/default', 'dset': 'internal/music_10k_32khz'})

    # One row per job: (filter.idx_band, processor.use, processor.power_std).
    band_settings = (
        (0, False, 0.4),
        (1, False, 0.4),
        (2, True, 0.4),
        (3, True, 0.75),
    )

    with launcher.job_array():
        for band_index, use_processor, power_std in band_settings:
            launcher({
                'filter.use': True,
                'filter.idx_band': band_index,
                "processor.use": use_processor,
                'processor.power_std': power_std,
            })
audiocraft/grids/diffusion/_explorers.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import treetable as tt
8
+
9
+ from .._base_explorers import BaseExplorer
10
+
11
+
12
class DiffusionExplorer(BaseExplorer):
    """Explorer describing the stages and table layout for diffusion grids."""

    # Metric names declared for this explorer; they are read outside this class.
    eval_metrics = ["sisnr", "visqol"]

    def stages(self):
        """Return the names of the stages tracked for each experiment."""
        return ["train", "valid", "valid_ema", "evaluate", "evaluate_ema"]

    def get_grid_meta(self):
        """Return the meta-information columns displayed for each XP/job."""
        index_column = tt.leaf("index", align=">")
        name_column = tt.leaf("name", wrap=140)
        state_column = tt.leaf("state")
        signature_column = tt.leaf("sig", align=">")
        return [index_column, name_column, state_column, signature_column]

    def get_grid_metrics(self):
        """Return the metric groups displayed in the tracking table.

        There is one group per stage: the training group shows the epoch and
        the loss, both validation groups show the loss only, and both
        evaluation groups show the same five ``rvm`` columns.
        """
        loss_format = ".3%"
        rvm_format = ".4f"
        rvm_names = ("rvm", "rvm_0", "rvm_1", "rvm_2", "rvm_3")

        train_group = tt.group(
            "train",
            [tt.leaf("epoch"), tt.leaf("loss", loss_format)],
            align=">",
        )
        valid_groups = [
            tt.group(stage, [tt.leaf("loss", loss_format)], align=">")
            for stage in ("valid", "valid_ema")
        ]
        evaluate_groups = [
            tt.group(stage, [tt.leaf(name, rvm_format) for name in rvm_names], align=">")
            for stage in ("evaluate", "evaluate_ema")
        ]
        return [train_group, *valid_groups, *evaluate_groups]