unpairedelectron07 committed
Commit
e3061ad
1 Parent(s): 699b46d

Upload 7 files

audiocraft/solvers/audiogen.py ADDED
@@ -0,0 +1,19 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from . import builders, musicgen
+
+
+ class AudioGenSolver(musicgen.MusicGenSolver):
+     """Solver for the AudioGen re-implementation training task.
+
+     Note that this implementation does not strictly follow
+     the method proposed in https://arxiv.org/abs/2209.15352
+     but is derived from MusicGen's training pipeline.
+
+     More information can be found in the AudioGen model card.
+     """
+     DATASET_TYPE: builders.DatasetType = builders.DatasetType.SOUND
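Nothing else changes relative to MusicGen: the solver inherits the full training pipeline and only swaps the dataset type to sound data. A minimal sketch of how that class attribute is consumed downstream (mirroring `get_audio_datasets` from builders.py below; the config object `cfg` is assumed to come from the usual Hydra/Dora config tree):

    from audiocraft.solvers import builders
    from audiocraft.solvers.audiogen import AudioGenSolver

    # DATASET_TYPE selects SoundDataset instead of MusicDataset when
    # the solver builds its dataloaders.
    dataloaders = builders.get_audio_datasets(cfg, AudioGenSolver.DATASET_TYPE)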
audiocraft/solvers/base.py ADDED
@@ -0,0 +1,631 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from abc import ABC, abstractmethod
+ from contextlib import contextmanager
+ from pathlib import Path
+ import typing as tp
+
+ import flashy
+ import omegaconf
+ import torch
+ from torch import nn
+
+ from .. import optim
+ from ..optim import fsdp
+ from ..utils import checkpoint
+ from ..utils.autocast import TorchAutocast
+ from ..utils.best_state import BestStateDictManager
+ from ..utils.deadlock import DeadlockDetect
+ from ..utils.profiler import Profiler
+ from ..utils.utils import copy_state, dict_from_config, model_hash, with_rank_rng
+
+
+ class StandardSolver(ABC, flashy.BaseSolver):
+     """Standard solver for AudioCraft.
+
+     The standard solver implements a base training loop with the following stages,
+     which are all expected to be defined for solvers in AudioCraft:
+     train, valid, evaluate and generate. It also provides default management of
+     Dora history replay, checkpoint management across epochs, and logging configuration.
+
+     AudioCraft solvers must inherit from the StandardSolver and define the methods
+     associated with each stage as well as the show, build_model and build_dataloaders methods.
+     """
+     def __init__(self, cfg: omegaconf.DictConfig):
+         super().__init__()
+         self.logger.info(f"Instantiating solver {self.__class__.__name__} for XP {self.xp.sig}")
+         self.logger.info(f"All XP logs are stored in {self.xp.folder}")
+         self.cfg = cfg
+         self.device = cfg.device
+         self.model: nn.Module
+         self._continue_best_source_keys = ['best_state', 'fsdp_best_state']
+         self._fsdp_modules: tp.List[fsdp.FSDP] = []
+         self._ema_sources: nn.ModuleDict = nn.ModuleDict()
+         self.ema: tp.Optional[optim.ModuleDictEMA] = None
+         self.dataloaders: tp.Dict[str, torch.utils.data.DataLoader] = dict()
+         self._log_updates = self.cfg.logging.get('log_updates', 10)
+         if self.cfg.logging.log_tensorboard:
+             self.init_tensorboard(**self.cfg.get('tensorboard'))
+         if self.cfg.logging.log_wandb and self:
+             self.init_wandb(**self.cfg.get('wandb'))
+         # keep a copy of the best performing state for stateful objects
+         # used for evaluation and generation stages
+         dtype_best: tp.Optional[torch.dtype] = None
+         if self.cfg.fsdp.use:
+             dtype_best = getattr(torch, self.cfg.fsdp.param_dtype)  # type: ignore
+             assert isinstance(dtype_best, torch.dtype)
+         elif self.cfg.autocast:
+             dtype_best = getattr(torch, self.cfg.autocast_dtype)  # type: ignore
+             assert isinstance(dtype_best, torch.dtype)
+         self.best_state: BestStateDictManager = BestStateDictManager(dtype=dtype_best)
+         # Hacky support for keeping a copy of the full best state in rank0.
+         self.fsdp_best_state: tp.Dict[str, tp.Any] = {}
+         self.register_stateful('best_state', 'fsdp_best_state')  # register best_state object to keep it in state_dict
+         self._new_best_state: bool = False  # should save a new checkpoint
+         # instantiate datasets and appropriate number of updates per epoch
+         self.build_dataloaders()
+         if self.cfg.execute_only is None:
+             assert 'train' in self.dataloaders, "The train dataset split must be provided."
+             assert 'valid' in self.dataloaders, "The valid dataset split must be provided."
+         self.train_updates_per_epoch = len(self.dataloaders['train']) if 'train' in self.dataloaders else 0
+         if self.cfg.optim.updates_per_epoch:
+             self.train_updates_per_epoch = self.cfg.optim.updates_per_epoch
+         self.total_updates = self.train_updates_per_epoch * self.cfg.optim.epochs
+         # instantiate model & exponential moving average on the model
+         self.build_model()
+         self.logger.info("Model hash: %s", model_hash(self.model))
+         assert 'model' in self.stateful.sources, \
+             "Please register the model to stateful with self.register_stateful('model') in build_model."
+         self.profiler = Profiler(self.model, **self.cfg.profiler)
+         self.initialize_ema()
+         self.register_stateful('ema')
+         assert self.ema is None or 'ema' in self.stateful.sources, \
+             "Please register the ema to stateful with self.register_stateful('ema') in build_model."
+         self.deadlock_detect = DeadlockDetect(**self.cfg.deadlock)
+         # basic statistics on the trained model
+         model_size = sum(p.numel() for p in self.model.parameters() if p.requires_grad) / 1e6
+         # one copy of grad, one copy of momentum, one copy of denominator, plus the
+         # model weights themselves, at 4 bytes per float!
+         mem_usage = model_size * 4 * 4 / 1000
+         self.logger.info("Model size: %.2f M params", model_size)
+         self.logger.info("Base memory usage, with model, grad and optim: %.2f GB", mem_usage)
+
+     @property
+     def autocast(self):
+         """Convenient autocast (or not) using the solver configuration."""
+         return TorchAutocast(enabled=self.cfg.autocast, device_type=self.device, dtype=self.autocast_dtype)
+
+     def _get_state_source(self, name) -> flashy.state.StateDictSource:
+         # Internal utility to get a state source from the solver
+         return self.stateful.sources[name]
+
+     @property
+     def best_metric_name(self) -> tp.Optional[str]:
+         """Metric name used to identify the best state. This metric should be stored in the metrics
+         of the stage used for best state identification (most likely, `valid`). If None, then
+         no best state is saved.
+         """
+         return None
+
+     def register_best_state(self, *args: str):
+         """Register state sources in `BestStateDictManager` to keep their best states along with their
+         latest states. The best state will be used at evaluation stages instead of the latest states.
+
+         Shortcut around the `BestStateDictManager.register` method. You can pass any number of
+         attributes, including nested attributes; those will be included in the checkpoints
+         and automatically restored when `BaseSolver.restore` is called.
+         """
+         for name in args:
+             state_source = self._get_state_source(name)
+             assert name in self.stateful.sources, "Registered states in best should be registered in stateful first!"
+             self.best_state.register(name, state_source)
+
+     def register_ema(self, *args: str):
+         """Register state sources for exponential moving average.
+
+         The registered sources are used to instantiate a ModuleDictEMA instance.
+         The ModuleDictEMA keeps a `nn.ModuleDict` module that is updated when self.ema.step() is called
+         and swapped with the original state sources with the self.swap_ema_state() method.
+
+         Usage:
+             self.register_ema('model')
+         """
+         assert self.ema is None, "Cannot register state source to already instantiated EMA."
+         for name in args:
+             self._ema_sources[name] = getattr(self, name)
+
+     def wrap_with_fsdp(self, model: torch.nn.Module, *args, **kwargs):
+         model = fsdp.wrap_with_fsdp(self.cfg.fsdp, model, *args, **kwargs)
+         if isinstance(model, fsdp.FSDP):
+             self._fsdp_modules.append(model)
+         return model
+
+     def update_best_state_from_stage(self, stage_name: str = 'valid'):
+         """Update latest best state based on pending metrics of a given stage. This method relies
+         on the `BestStateDictManager.update` method to update the best state_dict with latest weights
+         if the registered states happen to match the best performing setup.
+         """
+         if self.best_metric_name is None:
+             # when no best metric is defined, the last state is always the best
+             self._new_best_state = True
+             self.logger.info("Updating best state with current state.")
+         else:
+             assert stage_name in self._pending_metrics, f"Metrics for stage {stage_name} not found."
+             assert self.best_metric_name in self._pending_metrics[stage_name], \
+                 f"Best metric not found in {stage_name} metrics. Cannot register best state"
+             current_score = self._pending_metrics[stage_name][self.best_metric_name]
+             all_best_metric_scores = [
+                 past_metrics[stage_name][self.best_metric_name]
+                 for past_metrics in self.history
+             ]
+             all_best_metric_scores.append(current_score)
+             best_score = min(all_best_metric_scores)
+             self._new_best_state = current_score == best_score
+             if self._new_best_state:
+                 old_best = min(all_best_metric_scores[:-1] + [float('inf')])
+                 self.logger.info(
+                     f"New best state with {self.best_metric_name}={current_score:.3f} (was {old_best:.3f})")
+
+         if self._new_best_state:
+             if self.cfg.fsdp.use:
+                 # this will give an empty state dict on all ranks but the rank 0
+                 # which will have a copy in memory of the full model.
+                 with fsdp.switch_to_full_state_dict(self._fsdp_modules):
+                     for name in self.best_state.states.keys():
+                         state_source = self._get_state_source(name)
+                         self.best_state.update(name, state_source)
+                     # we save to a different dict.
+                     self.fsdp_best_state.update(self.best_state.state_dict())
+                 # We cannot efficiently load fsdp_best_state when using FSDP,
+                 # so we have to do a second pass, with the local shards.
+             for name in self.best_state.states.keys():
+                 state_source = self._get_state_source(name)
+                 self.best_state.update(name, state_source)
+
+     def _load_new_state_dict(self, state_dict: dict) -> dict:
+         old_states = {}
+         for name, new_state in state_dict.items():
+             state_source = self._get_state_source(name)
+             old_states[name] = copy_state(state_source.state_dict())
+             state_source.load_state_dict(new_state)
+         return old_states
+
+     @contextmanager
+     def swap_best_state(self):
+         self.logger.debug(f"Swapping to best state for: {', '.join(self.best_state.state_dict().keys())}")
+         old_states = self._load_new_state_dict(self.best_state.state_dict())
+         try:
+             yield
+         finally:
+             self.logger.debug("Swapping back from best to original state")
+             for name, old_state in old_states.items():
+                 state_source = self._get_state_source(name)
+                 state_source.load_state_dict(old_state)
+
+     @contextmanager
+     def swap_ema_state(self):
+         if self.ema is None:
+             yield
+         else:
+             ema_state_dict = self.ema.state_dict()['state']
+             self.logger.debug(f"Swapping to EMA state for: {', '.join(ema_state_dict.keys())}")
+             old_states = self._load_new_state_dict(ema_state_dict)
+             try:
+                 yield
+             finally:
+                 self.logger.debug("Swapping back from EMA state to original state")
+                 for name, old_state in old_states.items():
+                     state_source = self._get_state_source(name)
+                     state_source.load_state_dict(old_state)
+
+     @property
+     def is_training(self):
+         return self.current_stage == 'train'
+
+     def log_model_summary(self, model: nn.Module):
+         """Log model summary, architecture and size of the model."""
+         self.logger.info(model)
+         mb = sum(p.numel() for p in model.parameters()) * 4 / 2 ** 20
+         self.logger.info("Size: %.1f MB", mb)
+
+     @abstractmethod
+     def build_model(self):
+         """Method to implement to initialize model."""
+         ...
+
+     def initialize_ema(self):
+         """Initialize exponential moving average with the registered sources.
+         The EMA object is created if the optim.ema.model.decay value is non-null.
+         """
+         from .builders import get_ema
+         self.ema = get_ema(self._ema_sources, self.cfg.optim.ema)
+         if self.ema is None:
+             self.logger.info('No EMA on the model.')
+         else:
+             assert self.cfg.optim.ema.updates > 0
+             self.logger.info(
+                 f'Initializing EMA on the model with decay = {self.ema.decay}'
+                 f' every {self.cfg.optim.ema.updates} updates'
+             )
+
+     @abstractmethod
+     def build_dataloaders(self):
+         """Method to implement to initialize dataloaders."""
+         ...
+
+     @abstractmethod
+     def show(self):
+         """Method to log any information without running the job."""
+         ...
+
+     @property
+     def log_updates(self):
+         # convenient access to log updates
+         return self._log_updates
+
+     def checkpoint_path(self, **kwargs):
+         kwargs.setdefault('use_fsdp', self.cfg.fsdp.use)
+         return self.folder / checkpoint.checkpoint_name(**kwargs)
+
+     def epoch_checkpoint_path(self, epoch: int, **kwargs):
+         kwargs.setdefault('use_fsdp', self.cfg.fsdp.use)
+         return self.folder / checkpoint.checkpoint_name(str(epoch), **kwargs)
+
+     def checkpoint_path_with_name(self, name: str, **kwargs):
+         kwargs.setdefault('use_fsdp', self.cfg.fsdp.use)
+         return self.folder / checkpoint.checkpoint_name(name=name, **kwargs)
+
+     def save_checkpoints(self):
+         """Save checkpoint, optionally keeping a copy for a given epoch."""
+         is_sharded = self.cfg.fsdp.use
+         if not flashy.distrib.is_rank_zero() and not is_sharded:
+             return
+         self.logger.info("Model hash: %s", model_hash(self.model))
+         state = self.state_dict()
+         epoch = self.epoch - 1  # pushing metrics will increase the epoch in Flashy, so we do -1 here
+
+         # save minimal state_dict as new checkpoint every X epoch
+         if self.cfg.checkpoint.save_every:
+             if epoch % self.cfg.checkpoint.save_every == 0:
+                 minimal_state = state
+                 if self.cfg.checkpoint.keep_every_states is not None and len(self.cfg.checkpoint.keep_every_states) > 0:
+                     minimal_state = {
+                         name: source for name, source in state.items()
+                         if name in self.cfg.checkpoint.keep_every_states
+                     }
+                 epoch_checkpoint_path = self.epoch_checkpoint_path(epoch)
+                 checkpoint.save_checkpoint(minimal_state, epoch_checkpoint_path, is_sharded)
+
+         # save checkpoint as latest checkpoint
+         if self.cfg.checkpoint.save_last:
+             last_checkpoint_path = self.checkpoint_path()
+             checkpoint.save_checkpoint(state, last_checkpoint_path, is_sharded)
+
+         # flush any stale checkpoint to reduce disk footprint
+         checkpoint.flush_stale_checkpoints(self.checkpoint_path())
+
+     def load_from_pretrained(self, name: str) -> dict:
+         raise NotImplementedError("Solver does not provide a way to load pretrained models.")
+
+     def load_checkpoints(self, load_best: bool = False, ignore_state_keys: tp.List[str] = []) -> tp.Optional[dict]:
+         """Load last checkpoint or the one specified in continue_from.
+
+         Args:
+             load_best (bool): Whether to load from best state dict or not.
+                 Best state dict is always used when not loading the current xp.
+             ignore_state_keys (list of str): List of sources to ignore when loading the state, e.g. `optimizer`.
+         Returns:
+             state (dict, optional): The loaded state dictionary.
+         """
+         # load checkpoints from xp folder or cfg.continue_from
+         is_sharded = self.cfg.fsdp.use
+         load_from_path: tp.Optional[Path] = None
+         checkpoint_source: tp.Optional[checkpoint.CheckpointSource] = None
+
+         if load_best:
+             self.logger.info("Trying to load state_dict from best state.")
+
+         state: tp.Optional[dict] = None
+         rank0_checkpoint_path = self.checkpoint_path(use_fsdp=False)
+         current_checkpoint_path = self.checkpoint_path()
+         _pretrained_prefix = '//pretrained/'
+         continue_pretrained = (self.cfg.continue_from or '').startswith(_pretrained_prefix)
+         if rank0_checkpoint_path.exists():
+             self.logger.info(f"Loading existing checkpoint: {current_checkpoint_path}")
+             load_from_path = current_checkpoint_path
+             checkpoint.check_sharded_checkpoint(current_checkpoint_path, rank0_checkpoint_path)
+             checkpoint_source = checkpoint.CheckpointSource.CURRENT_XP
+         elif self.cfg.continue_from and not continue_pretrained:
+             self.logger.info(f"Continuing from provided checkpoint: {self.cfg.continue_from}")
+             # we're always continuing from consolidated checkpoints: self.cfg.use_fsdp and not continue_best
+             load_from_path = checkpoint.resolve_checkpoint_path(self.cfg.continue_from, use_fsdp=False)
+             if load_from_path is None:
+                 self.logger.error('Could not resolve the continue_from checkpoint %s', self.cfg.continue_from)
+                 raise RuntimeError(f'Could not resolve continue_from checkpoint {self.cfg.continue_from}')
+             checkpoint_source = checkpoint.CheckpointSource.OTHER
+
+         if load_from_path is not None:
+             state = checkpoint.load_checkpoint(load_from_path, is_sharded)
+         elif continue_pretrained:
+             self.logger.info("Loading a pretrained model. Ignoring 'load_best' and 'ignore_state_keys' params.")
+             state = self.load_from_pretrained(self.cfg.continue_from[len(_pretrained_prefix):])
+             checkpoint_source = checkpoint.CheckpointSource.PRETRAINED
+             load_best = True
+
+         # checkpoints are not from the current xp, we only retrieve the best state
+         if checkpoint_source is not None and checkpoint_source != checkpoint.CheckpointSource.CURRENT_XP:
+             assert state is not None
+             self.logger.info("Checkpoint source is not the current xp: Load state_dict from best state.")
+             load_best = True
+             state = {key: state[key] for key in self._continue_best_source_keys if key in state}
+             # loaded checkpoints are FSDP checkpoints: we're reading the best state
+             # from FSDP and we drop the regular best_state
+             if 'fsdp_best_state' in state and state['fsdp_best_state']:
+                 state.pop('best_state', None)
+                 self.logger.info("... Loaded checkpoint has FSDP best state")
+             # FSDP is enabled in the solver, if the loaded checkpoints do not have FSDP support
+             # then we're initializing FSDP best state with the regular best state
+             elif self.cfg.fsdp.use:
+                 if 'fsdp_best_state' not in state or not state['fsdp_best_state']:
+                     # we swap non-FSDP checkpoints best_state to FSDP-compatible best state
+                     state['fsdp_best_state'] = state.pop('best_state')
+                     self.logger.info("... Loaded checkpoint does not have FSDP best state. Use regular best state")
+
+         if state is not None:
+             if load_best:
+                 self.logger.info("Ignoring keys when loading best %r", ignore_state_keys)
+                 for key in set(ignore_state_keys):
+                     if key in state:
+                         state.pop(key)
+                 has_best_state = 'best_state' in state or 'fsdp_best_state' in state
+                 assert has_best_state, ("Trying to load best state but neither 'best_state' "
+                                         "nor 'fsdp_best_state' found in checkpoints.")
+             self.load_state_dict(state)
+
+         # for FSDP, let's make extra sure nothing bad happened with out of sync
+         # checkpoints across workers.
+         epoch = float(self.epoch)
+         avg_epoch = flashy.distrib.average_metrics({'epoch': epoch})['epoch']
+         if avg_epoch != epoch:
+             raise RuntimeError(
+                 f"Inconsistent loading of checkpoints happened, our epoch is {epoch} "
+                 f"but average of epochs is {avg_epoch}, at least one gpu must have a "
+                 "different epoch number.")
+
+         # on load_best, properly reinitialize state_dict, best states and ema
+         # otherwise we load from the current xp and don't alter anything
+         if load_best:
+             self.logger.info("Loading state_dict from best state.")
+             if not self.cfg.fsdp.use and self.fsdp_best_state:
+                 # loading from an FSDP checkpoint but with FSDP deactivated
+                 self.logger.info("... Loading from FSDP best state dict.")
+                 self.best_state.load_state_dict(self.fsdp_best_state)
+
+             # if load_best, we permanently override the regular state_dict with the best state
+             if self.cfg.fsdp.use:
+                 self.logger.info("FSDP is used, loading from FSDP best state.")
+                 with fsdp.switch_to_full_state_dict(self._fsdp_modules):
+                     # this might be really fragile but okay for now.
+                     self.load_state_dict(self.fsdp_best_state)
+             else:
+                 # we permanently swap the stateful objects to their best state
+                 self._load_new_state_dict(self.best_state.state_dict())
+
+             # the EMA modules should also be instantiated with best state.
+             # the easiest way to do so is to reinitialize a new EMA with best state loaded.
+             if self.ema is not None:
+                 self.logger.info("Re-initializing EMA from best state")
+                 self.initialize_ema()
+
+             if self.cfg.fsdp.use:
+                 self.logger.info("Re-initializing best state after using FSDP best state.")
+                 for name in self.best_state.states.keys():
+                     state_source = self._get_state_source(name)
+                     self.best_state.update(name, state_source)
+
+         return state
+
+     def restore(self, load_best: bool = False, replay_metrics: bool = False,
+                 ignore_state_keys: tp.List[str] = []) -> bool:
+         """Restore the status of a solver for a given xp.
+
+         Args:
+             load_best (bool): if `True`, load the best state from the checkpoint.
+             replay_metrics (bool): if `True`, logs all the metrics from past epochs.
+             ignore_state_keys (list of str): list of sources to ignore when loading the state, e.g. `optimizer`.
+         """
+         self.logger.info("Restoring weights and history.")
+         restored_checkpoints = self.load_checkpoints(load_best, ignore_state_keys)
+
+         self.logger.info("Model hash: %s", model_hash(self.model))
+
+         if replay_metrics and len(self.history) > 0:
+             self.logger.info("Replaying past metrics...")
+             for epoch, stages in enumerate(self.history):
+                 for stage_name, metrics in stages.items():
+                     # We manually log the metrics summary to the result logger
+                     # as we don't want to add them to the pending metrics
+                     self.result_logger._log_summary(stage_name, metrics, step=epoch + 1, step_name='epoch',
+                                                     formatter=self.get_formatter(stage_name))
+         return restored_checkpoints is not None
+
+     def commit(self, save_checkpoints: bool = True):
+         """Commit metrics to dora and save checkpoints at the end of an epoch."""
+         # we override commit to introduce more complex checkpoint saving behaviors
+         self.history.append(self._pending_metrics)  # This will increase self.epoch
+         if save_checkpoints:
+             self.save_checkpoints()
+         self._start_epoch()
+         if flashy.distrib.is_rank_zero():
+             self.xp.link.update_history(self.history)
+
+     def run_epoch(self):
+         """Run a single epoch with all stages.
+
+         Metrics for a given stage are stored in _pending_metrics and committed by the solver afterwards.
+         Children solvers can extend this method with custom behavior, e.g.:
+
+             def run_epoch(self):
+                 ...  # custom code
+                 super().run_epoch()
+                 ...  # custom code
+         """
+         self.run_stage('train', self.train)
+         with torch.no_grad():
+             with self.swap_ema_state():
+                 self.run_stage('valid', self.valid)
+                 # the best state is updated with EMA states if available
+                 self.update_best_state_from_stage('valid')
+             with self.swap_best_state():
+                 if self.should_run_stage('evaluate'):
+                     self.run_stage('evaluate', self.evaluate)
+                 if self.should_run_stage('generate'):
+                     self.run_stage('generate', with_rank_rng()(self.generate))
+
+     def run(self):
+         """Training loop."""
+         assert len(self.state_dict()) > 0
+         self.restore(replay_metrics=True)  # load checkpoint and replay history
+         self.log_hyperparams(dict_from_config(self.cfg))
+         for epoch in range(self.epoch, self.cfg.optim.epochs + 1):
+             if self.should_stop_training():
+                 return
+             self.run_epoch()
+             # Commit will send the metrics to Dora and save checkpoints by default.
+             self.commit()
+
+     def should_stop_training(self) -> bool:
+         """Check whether we should stop training or not."""
+         return self.epoch > self.cfg.optim.epochs
+
+     def should_run_stage(self, stage_name) -> bool:
+         """Check whether we want to run the specified stages."""
+         stage_every = self.cfg[stage_name].get('every', None)
+         is_last_epoch = self.epoch == self.cfg.optim.epochs
+         is_epoch_every = (stage_every and self.epoch % stage_every == 0)
+         return is_last_epoch or is_epoch_every
+
+     @abstractmethod
+     def run_step(self, idx: int, batch: tp.Any, metrics: dict):
+         """Perform one training or valid step on a given batch."""
+         ...
+
+     def common_train_valid(self, dataset_split: str, **kwargs: tp.Any):
+         """Common logic for train and valid stages."""
+         self.model.train(self.is_training)
+
+         loader = self.dataloaders[dataset_split]
+         # get a different order for distributed training, otherwise this will get ignored
+         if flashy.distrib.world_size() > 1 \
+                 and isinstance(loader.sampler, torch.utils.data.distributed.DistributedSampler):
+             loader.sampler.set_epoch(self.epoch)
+         updates_per_epoch = self.train_updates_per_epoch if self.is_training else len(loader)
+         if self.cfg.benchmark_no_load:
+             self.logger.warning("Fake loading for benchmarking: re-using first batch")
+             batch = next(iter(loader))
+             loader = [batch] * updates_per_epoch  # type: ignore
+         lp = self.log_progress(self.current_stage, loader, total=updates_per_epoch, updates=self.log_updates)
+         average = flashy.averager()  # epoch wise average
+         instant_average = flashy.averager()  # average between two logging
+         metrics: dict = {}
+
+         with self.profiler, self.deadlock_detect:  # profiler will only run for the first 20 updates.
+             for idx, batch in enumerate(lp):
+                 self.deadlock_detect.update('batch')
+                 if idx >= updates_per_epoch:
+                     break
+                 metrics = {}
+                 metrics = self.run_step(idx, batch, metrics)
+                 self.deadlock_detect.update('step')
+                 # run EMA step
+                 if self.ema is not None and self.is_training and (idx + 1) % self.cfg.optim.ema.updates == 0:
+                     self.logger.debug("EMA model step")
+                     self.ema.step()
+                 self.deadlock_detect.update('ema')
+                 self.profiler.step()
+                 instant_metrics = instant_average(metrics)
+                 if lp.update(**instant_metrics):
+                     instant_average = flashy.averager()  # reset averager between two logging
+                 metrics = average(metrics)  # epoch wise average
+                 self.deadlock_detect.update('end_batch')
+
+         metrics = flashy.distrib.average_metrics(metrics, updates_per_epoch)
+         return metrics
+
+     def train(self):
+         """Train stage."""
+         return self.common_train_valid('train')
+
+     def valid(self):
+         """Valid stage."""
+         return self.common_train_valid('valid')
+
+     @abstractmethod
+     def evaluate(self):
+         """Evaluate stage."""
+         ...
+
+     @abstractmethod
+     def generate(self):
+         """Generate stage."""
+         ...
+
+     def run_one_stage(self, stage_name: str):
+         """Run only the specified stage.
+         This method is useful to only generate samples from a trained experiment
+         or rerun the validation or evaluation stages.
+         """
+         fn = {
+             'generate': with_rank_rng()(self.generate),
+             'evaluate': self.evaluate,
+             'valid': self.valid,
+         }
+         if stage_name not in fn:
+             raise ValueError(f'Trying to run stage {stage_name} is not supported.')
+         assert len(self.state_dict()) > 0
+         self._start_epoch()
+         with torch.no_grad(), self.swap_best_state():
+             self.run_stage(stage_name, fn[stage_name])
+         if not self.cfg.execute_inplace:
+             self.commit(save_checkpoints=False)
+
+     @staticmethod
+     def get_eval_solver_from_sig(sig: str, dtype: tp.Optional[str] = None,
+                                  device: tp.Optional[str] = None, autocast: bool = True,
+                                  batch_size: tp.Optional[int] = None,
+                                  override_cfg: tp.Optional[tp.Union[dict, omegaconf.DictConfig]] = None,
+                                  **kwargs):
+         """Mostly a convenience function around audiocraft.train.get_solver_from_sig,
+         populating all the proper params, deactivating EMA and FSDP, and loading the best state:
+         basically all you need to get a solver ready to "play" with in single GPU mode
+         and with minimal memory overhead.
+
+         Args:
+             sig (str): signature to load.
+             dtype (str or None): potential dtype, as a string, e.g. 'float16'.
+             device (str or None): potential device, as a string, e.g. 'cuda'.
+             override_cfg (dict or omegaconf.DictConfig or None): potential extra config overrides,
+                 merged with the overrides set by this function.
+         """
+         from audiocraft import train
+         our_override_cfg: tp.Dict[str, tp.Any] = {'optim': {'ema': {'use': False}}}
+         our_override_cfg['autocast'] = autocast
+         if dtype is not None:
+             our_override_cfg['dtype'] = dtype
+         if device is not None:
+             our_override_cfg['device'] = device
+         if batch_size is not None:
+             our_override_cfg['dataset'] = {'batch_size': batch_size}
+         if override_cfg is None:
+             override_cfg = {}
+         override_cfg = omegaconf.OmegaConf.merge(
+             omegaconf.DictConfig(override_cfg), omegaconf.DictConfig(our_override_cfg))  # type: ignore
+         solver = train.get_solver_from_sig(
+             sig, override_cfg=override_cfg,
+             load_best=True, disable_fsdp=True,
+             ignore_state_keys=['optimizer', 'ema'], **kwargs)
+         solver.model.eval()
+         return solver
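The typical entry point for inspecting a finished run is `get_eval_solver_from_sig` above. A minimal usage sketch (the XP signature below is hypothetical, standing in for a real Dora signature):

    from audiocraft.solvers.base import StandardSolver

    # 'f1a2b3c4' is a hypothetical Dora XP signature of an already trained run.
    solver = StandardSolver.get_eval_solver_from_sig(
        'f1a2b3c4', dtype='float16', device='cuda', batch_size=8)
    # The solver comes back with EMA/FSDP disabled and the best state loaded,
    # so individual stages can be re-run directly, e.g.:
    solver.run_one_stage('generate')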
audiocraft/solvers/builders.py ADDED
@@ -0,0 +1,366 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ All the functions to build the relevant solvers and the objects
+ they use, from the Hydra config.
+ """
+
+ from enum import Enum
+ import logging
+ import typing as tp
+
+ import dora
+ import flashy
+ import omegaconf
+ import torch
+ from torch import nn
+ from torch.optim import Optimizer
+ # LRScheduler was renamed in some torch versions
+ try:
+     from torch.optim.lr_scheduler import LRScheduler  # type: ignore
+ except ImportError:
+     from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
+
+ from .base import StandardSolver
+ from .. import adversarial, data, losses, metrics, optim
+ from ..utils.utils import dict_from_config, get_loader
+
+
+ logger = logging.getLogger(__name__)
+
+
+ class DatasetType(Enum):
+     AUDIO = "audio"
+     MUSIC = "music"
+     SOUND = "sound"
+
+
+ def get_solver(cfg: omegaconf.DictConfig) -> StandardSolver:
+     """Instantiate solver from config."""
+     from .audiogen import AudioGenSolver
+     from .compression import CompressionSolver
+     from .musicgen import MusicGenSolver
+     from .diffusion import DiffusionSolver
+     from .magnet import MagnetSolver, AudioMagnetSolver
+     klass = {
+         'compression': CompressionSolver,
+         'musicgen': MusicGenSolver,
+         'audiogen': AudioGenSolver,
+         'magnet': MagnetSolver,
+         'audio_magnet': AudioMagnetSolver,
+         'lm': MusicGenSolver,  # backward compatibility
+         'diffusion': DiffusionSolver,
+         'sound_lm': AudioGenSolver,  # backward compatibility
+     }[cfg.solver]
+     return klass(cfg)  # type: ignore
+
+
+ def get_optim_parameter_groups(model: nn.Module):
+     """Create parameter groups for the model, using each module's
+     `make_optim_group` method when defined to create the different groups.
+
+     Args:
+         model (nn.Module): torch model
+     Returns:
+         List of parameter groups
+     """
+     seen_params: tp.Set[nn.parameter.Parameter] = set()
+     other_params = []
+     groups = []
+     for name, module in model.named_modules():
+         if hasattr(module, 'make_optim_group'):
+             group = module.make_optim_group()
+             params = set(group['params'])
+             assert params.isdisjoint(seen_params)
+             seen_params |= set(params)
+             groups.append(group)
+     for param in model.parameters():
+         if param not in seen_params:
+             other_params.append(param)
+     groups.insert(0, {'params': other_params})
+     parameters = groups
+     return parameters
+
+
+ def get_optimizer(params: tp.Union[nn.Module, tp.Iterable[torch.Tensor]], cfg: omegaconf.DictConfig) -> Optimizer:
+     """Build torch optimizer from config and set of parameters.
+     Supported optimizers: Adam, AdamW
+
+     Args:
+         params (nn.Module or iterable of torch.Tensor): Parameters to optimize.
+         cfg (DictConfig): Optimization-related configuration.
+     Returns:
+         torch.optim.Optimizer.
+     """
+     if 'optimizer' not in cfg:
+         if getattr(cfg, 'optim', None) is not None:
+             raise KeyError("Optimizer not found in config. Try instantiating optimizer from cfg.optim?")
+         else:
+             raise KeyError("Optimizer not found in config.")
+
+     parameters = get_optim_parameter_groups(params) if isinstance(params, nn.Module) else params
+     optimizer: torch.optim.Optimizer
+     if cfg.optimizer == 'adam':
+         optimizer = torch.optim.Adam(parameters, lr=cfg.lr, **cfg.adam)
+     elif cfg.optimizer == 'adamw':
+         optimizer = torch.optim.AdamW(parameters, lr=cfg.lr, **cfg.adam)
+     elif cfg.optimizer == 'dadam':
+         optimizer = optim.DAdaptAdam(parameters, lr=cfg.lr, **cfg.adam)
+     else:
+         raise ValueError(f"Unsupported Optimizer: {cfg.optimizer}")
+     return optimizer
+
+
+ def get_lr_scheduler(optimizer: torch.optim.Optimizer,
+                      cfg: omegaconf.DictConfig,
+                      total_updates: int) -> tp.Optional[LRScheduler]:
+     """Build torch learning rate scheduler from config and associated optimizer.
+     Supported learning rate schedulers: ExponentialLRScheduler, PlateauLRScheduler
+
+     Args:
+         optimizer (torch.optim.Optimizer): Optimizer.
+         cfg (DictConfig): Schedule-related configuration.
+         total_updates (int): Total number of updates.
+     Returns:
+         The learning rate scheduler, or None if not configured.
+     """
+     if 'lr_scheduler' not in cfg:
+         raise KeyError("LR Scheduler not found in config")
+
+     lr_sched: tp.Optional[LRScheduler] = None
+     if cfg.lr_scheduler == 'step':
+         lr_sched = torch.optim.lr_scheduler.StepLR(optimizer, **cfg.step)
+     elif cfg.lr_scheduler == 'exponential':
+         lr_sched = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=cfg.exponential)
+     elif cfg.lr_scheduler == 'cosine':
+         kwargs = dict_from_config(cfg.cosine)
+         warmup_steps = kwargs.pop('warmup')
+         lr_sched = optim.CosineLRScheduler(
+             optimizer, warmup_steps=warmup_steps, total_steps=total_updates, **kwargs)
+     elif cfg.lr_scheduler == 'polynomial_decay':
+         kwargs = dict_from_config(cfg.polynomial_decay)
+         warmup_steps = kwargs.pop('warmup')
+         lr_sched = optim.PolynomialDecayLRScheduler(
+             optimizer, warmup_steps=warmup_steps, total_steps=total_updates, **kwargs)
+     elif cfg.lr_scheduler == 'inverse_sqrt':
+         kwargs = dict_from_config(cfg.inverse_sqrt)
+         warmup_steps = kwargs.pop('warmup')
+         lr_sched = optim.InverseSquareRootLRScheduler(optimizer, warmup_steps=warmup_steps, **kwargs)
+     elif cfg.lr_scheduler == 'linear_warmup':
+         kwargs = dict_from_config(cfg.linear_warmup)
+         warmup_steps = kwargs.pop('warmup')
+         lr_sched = optim.LinearWarmupLRScheduler(optimizer, warmup_steps=warmup_steps, **kwargs)
+     elif cfg.lr_scheduler is not None:
+         raise ValueError(f"Unsupported LR Scheduler: {cfg.lr_scheduler}")
+     return lr_sched
+
+
+ def get_ema(module_dict: nn.ModuleDict, cfg: omegaconf.DictConfig) -> tp.Optional[optim.ModuleDictEMA]:
+     """Initialize Exponential Moving Average.
+
+     Args:
+         module_dict (nn.ModuleDict): ModuleDict for which to compute the EMA.
+         cfg (omegaconf.DictConfig): Optim EMA configuration.
+     Returns:
+         optim.ModuleDictEMA: EMA version of the ModuleDict.
+     """
+     kw: tp.Dict[str, tp.Any] = dict(cfg)
+     use = kw.pop('use', False)
+     decay = kw.pop('decay', None)
+     device = kw.pop('device', None)
+     if not use:
+         return None
+     if len(module_dict) == 0:
+         raise ValueError("Trying to build EMA but an empty module_dict source is provided!")
+     ema_module = optim.ModuleDictEMA(module_dict, decay=decay, device=device)
+     return ema_module
+
+
+ def get_loss(loss_name: str, cfg: omegaconf.DictConfig):
+     """Instantiate loss from configuration."""
+     klass = {
+         'l1': torch.nn.L1Loss,
+         'l2': torch.nn.MSELoss,
+         'mel': losses.MelSpectrogramL1Loss,
+         'mrstft': losses.MRSTFTLoss,
+         'msspec': losses.MultiScaleMelSpectrogramLoss,
+         'sisnr': losses.SISNR,
+     }[loss_name]
+     kwargs = dict(getattr(cfg, loss_name))
+     return klass(**kwargs)
+
+
+ def get_balancer(loss_weights: tp.Dict[str, float], cfg: omegaconf.DictConfig) -> losses.Balancer:
+     """Instantiate loss balancer from configuration for the provided weights."""
+     kwargs: tp.Dict[str, tp.Any] = dict_from_config(cfg)
+     return losses.Balancer(loss_weights, **kwargs)
+
+
+ def get_adversary(name: str, cfg: omegaconf.DictConfig) -> nn.Module:
+     """Initialize adversary from config."""
+     klass = {
+         'msd': adversarial.MultiScaleDiscriminator,
+         'mpd': adversarial.MultiPeriodDiscriminator,
+         'msstftd': adversarial.MultiScaleSTFTDiscriminator,
+     }[name]
+     adv_cfg: tp.Dict[str, tp.Any] = dict(getattr(cfg, name))
+     return klass(**adv_cfg)
+
+
+ def get_adversarial_losses(cfg) -> nn.ModuleDict:
+     """Initialize dict of adversarial losses from config."""
+     device = cfg.device
+     adv_cfg = getattr(cfg, 'adversarial')
+     adversaries = adv_cfg.get('adversaries', [])
+     adv_loss_name = adv_cfg['adv_loss']
+     feat_loss_name = adv_cfg.get('feat_loss')
+     normalize = adv_cfg.get('normalize', True)
+     feat_loss: tp.Optional[adversarial.FeatureMatchingLoss] = None
+     if feat_loss_name:
+         assert feat_loss_name in ['l1', 'l2'], f"Feature loss only supports L1 or L2 but {feat_loss_name} found."
+         loss = get_loss(feat_loss_name, cfg)
+         feat_loss = adversarial.FeatureMatchingLoss(loss, normalize)
+     loss = adversarial.get_adv_criterion(adv_loss_name)
+     loss_real = adversarial.get_real_criterion(adv_loss_name)
+     loss_fake = adversarial.get_fake_criterion(adv_loss_name)
+     adv_losses = nn.ModuleDict()
+     for adv_name in adversaries:
+         adversary = get_adversary(adv_name, cfg).to(device)
+         optimizer = get_optimizer(adversary.parameters(), cfg.optim)
+         adv_loss = adversarial.AdversarialLoss(
+             adversary,
+             optimizer,
+             loss=loss,
+             loss_real=loss_real,
+             loss_fake=loss_fake,
+             loss_feat=feat_loss,
+             normalize=normalize
+         )
+         adv_losses[adv_name] = adv_loss
+     return adv_losses
+
+
+ def get_visqol(cfg: omegaconf.DictConfig) -> metrics.ViSQOL:
+     """Instantiate ViSQOL metric from config."""
+     kwargs = dict_from_config(cfg)
+     return metrics.ViSQOL(**kwargs)
+
+
+ def get_fad(cfg: omegaconf.DictConfig) -> metrics.FrechetAudioDistanceMetric:
+     """Instantiate Frechet Audio Distance metric from config."""
+     kwargs = dict_from_config(cfg.tf)
+     xp = dora.get_xp()
+     kwargs['log_folder'] = xp.folder
+     return metrics.FrechetAudioDistanceMetric(**kwargs)
+
+
+ def get_kldiv(cfg: omegaconf.DictConfig) -> metrics.KLDivergenceMetric:
+     """Instantiate KL-Divergence metric from config."""
+     kld_metrics = {
+         'passt': metrics.PasstKLDivergenceMetric,
+     }
+     klass = kld_metrics[cfg.model]
+     kwargs = dict_from_config(cfg.get(cfg.model))
+     return klass(**kwargs)
+
+
+ def get_text_consistency(cfg: omegaconf.DictConfig) -> metrics.TextConsistencyMetric:
+     """Instantiate Text Consistency metric from config."""
+     text_consistency_metrics = {
+         'clap': metrics.CLAPTextConsistencyMetric
+     }
+     klass = text_consistency_metrics[cfg.model]
+     kwargs = dict_from_config(cfg.get(cfg.model))
+     return klass(**kwargs)
+
+
+ def get_chroma_cosine_similarity(cfg: omegaconf.DictConfig) -> metrics.ChromaCosineSimilarityMetric:
+     """Instantiate Chroma Cosine Similarity metric from config."""
+     assert cfg.model == 'chroma_base', "Only the 'chroma_base' method is supported for the chroma cosine similarity metric"
+     kwargs = dict_from_config(cfg.get(cfg.model))
+     return metrics.ChromaCosineSimilarityMetric(**kwargs)
+
+
+ def get_audio_datasets(cfg: omegaconf.DictConfig,
+                        dataset_type: DatasetType = DatasetType.AUDIO) -> tp.Dict[str, torch.utils.data.DataLoader]:
+     """Build AudioDataset from configuration.
+
+     Args:
+         cfg (omegaconf.DictConfig): Configuration.
+         dataset_type: The type of dataset to create.
+     Returns:
+         dict[str, torch.utils.data.DataLoader]: Map of dataloader for each data split.
+     """
+     dataloaders: dict = {}
+
+     sample_rate = cfg.sample_rate
+     channels = cfg.channels
+     seed = cfg.seed
+     max_sample_rate = cfg.datasource.max_sample_rate
+     max_channels = cfg.datasource.max_channels
+
+     assert cfg.dataset is not None, "Could not find dataset definition in config"
+
+     dataset_cfg = dict_from_config(cfg.dataset)
+     splits_cfg: dict = {}
+     splits_cfg['train'] = dataset_cfg.pop('train')
+     splits_cfg['valid'] = dataset_cfg.pop('valid')
+     splits_cfg['evaluate'] = dataset_cfg.pop('evaluate')
+     splits_cfg['generate'] = dataset_cfg.pop('generate')
+     execute_only_stage = cfg.get('execute_only', None)
+
+     for split, path in cfg.datasource.items():
+         if not isinstance(path, str):
+             continue  # skipping this as not a path
+         if execute_only_stage is not None and split != execute_only_stage:
+             continue
+         logger.info(f"Loading audio data split {split}: {str(path)}")
+         assert (
+             cfg.sample_rate <= max_sample_rate
+         ), f"Expecting a max sample rate of {max_sample_rate} for datasource but {sample_rate} found."
+         assert (
+             cfg.channels <= max_channels
+         ), f"Expecting a max number of channels of {max_channels} for datasource but {channels} found."
+
+         split_cfg = splits_cfg[split]
+         split_kwargs = {k: v for k, v in split_cfg.items()}
+         kwargs = {**dataset_cfg, **split_kwargs}  # split kwargs overrides default dataset_cfg
+         kwargs['sample_rate'] = sample_rate
+         kwargs['channels'] = channels
+
+         if kwargs.get('permutation_on_files') and cfg.optim.updates_per_epoch:
+             kwargs['num_samples'] = (
+                 flashy.distrib.world_size() * cfg.dataset.batch_size * cfg.optim.updates_per_epoch)
+
+         num_samples = kwargs['num_samples']
+         shuffle = kwargs['shuffle']
+
+         return_info = kwargs.pop('return_info')
+         batch_size = kwargs.pop('batch_size', None)
+         num_workers = kwargs.pop('num_workers')
+
+         if dataset_type == DatasetType.MUSIC:
+             dataset = data.music_dataset.MusicDataset.from_meta(path, **kwargs)
+         elif dataset_type == DatasetType.SOUND:
+             dataset = data.sound_dataset.SoundDataset.from_meta(path, **kwargs)
+         elif dataset_type == DatasetType.AUDIO:
+             dataset = data.info_audio_dataset.InfoAudioDataset.from_meta(path, return_info=return_info, **kwargs)
+         else:
+             raise ValueError(f"Dataset type is unsupported: {dataset_type}")
+
+         loader = get_loader(
+             dataset,
+             num_samples,
+             batch_size=batch_size,
+             num_workers=num_workers,
+             seed=seed,
+             collate_fn=dataset.collater if return_info else None,
+             shuffle=shuffle,
+         )
+         dataloaders[split] = loader
+
+     return dataloaders
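To illustrate how a solver typically consumes these helpers, a build_model implementation might wire them together roughly as below. This is only a sketch: MyModel is a placeholder, and the cfg.schedule key is an assumption following the config keys referenced in get_lr_scheduler above.

    # Inside a hypothetical StandardSolver subclass's build_model():
    self.model = MyModel(...).to(self.device)  # MyModel is a placeholder
    self.optimizer = builders.get_optimizer(self.model, self.cfg.optim)
    self.lr_scheduler = builders.get_lr_scheduler(
        self.optimizer, self.cfg.schedule, self.total_updates)
    # Registering makes these objects part of the solver's checkpointed state.
    self.register_stateful('model', 'optimizer', 'lr_scheduler')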
audiocraft/solvers/compression.py ADDED
@@ -0,0 +1,328 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import logging
+ import multiprocessing
+ from pathlib import Path
+ import typing as tp
+
+ import flashy
+ import omegaconf
+ import torch
+ from torch import nn
+
+ from . import base, builders
+ from .. import models, quantization
+ from ..utils import checkpoint
+ from ..utils.samples.manager import SampleManager
+ from ..utils.utils import get_pool_executor
+
+
+ logger = logging.getLogger(__name__)
+
+
+ class CompressionSolver(base.StandardSolver):
+     """Solver for compression task.
+
+     The compression task combines a set of perceptual and objective losses
+     to train an EncodecModel (composed of an encoder-decoder and a quantizer)
+     to perform high fidelity audio reconstruction.
+     """
+     def __init__(self, cfg: omegaconf.DictConfig):
+         super().__init__(cfg)
+         self.rng: torch.Generator  # set at each epoch
+         self.adv_losses = builders.get_adversarial_losses(self.cfg)
+         self.aux_losses = nn.ModuleDict()
+         self.info_losses = nn.ModuleDict()
+         assert not cfg.fsdp.use, "FSDP not supported by CompressionSolver."
+         loss_weights = dict()
+         for loss_name, weight in self.cfg.losses.items():
+             if loss_name in ['adv', 'feat']:
+                 for adv_name, _ in self.adv_losses.items():
+                     loss_weights[f'{loss_name}_{adv_name}'] = weight
+             elif weight > 0:
+                 self.aux_losses[loss_name] = builders.get_loss(loss_name, self.cfg)
+                 loss_weights[loss_name] = weight
+             else:
+                 self.info_losses[loss_name] = builders.get_loss(loss_name, self.cfg)
+         self.balancer = builders.get_balancer(loss_weights, self.cfg.balancer)
+         self.register_stateful('adv_losses')
+
+     @property
+     def best_metric_name(self) -> tp.Optional[str]:
+         # best model is the last for the compression model
+         return None
+
+     def build_model(self):
+         """Instantiate model and optimizer."""
+         # Model and optimizer
+         self.model = models.builders.get_compression_model(self.cfg).to(self.device)
+         self.optimizer = builders.get_optimizer(self.model.parameters(), self.cfg.optim)
+         self.register_stateful('model', 'optimizer')
+         self.register_best_state('model')
+         self.register_ema('model')
+
+     def build_dataloaders(self):
+         """Instantiate audio dataloaders for each stage."""
+         self.dataloaders = builders.get_audio_datasets(self.cfg)
+
+     def show(self):
+         """Show the compression model and employed adversarial loss."""
+         self.logger.info(f"Compression model with {self.model.quantizer.total_codebooks} codebooks:")
+         self.log_model_summary(self.model)
+         self.logger.info("Adversarial loss:")
+         self.log_model_summary(self.adv_losses)
+         self.logger.info("Auxiliary losses:")
+         self.logger.info(self.aux_losses)
+         self.logger.info("Info losses:")
+         self.logger.info(self.info_losses)
+
+     def run_step(self, idx: int, batch: torch.Tensor, metrics: dict):
+         """Perform one training or valid step on a given batch."""
+         x = batch.to(self.device)
+         y = x.clone()
+
+         qres = self.model(x)
+         assert isinstance(qres, quantization.QuantizedResult)
+         y_pred = qres.x
+         # Log bandwidth in kb/s
+         metrics['bandwidth'] = qres.bandwidth.mean()
+
+         if self.is_training:
+             d_losses: dict = {}
+             if len(self.adv_losses) > 0 and torch.rand(1, generator=self.rng).item() <= 1 / self.cfg.adversarial.every:
+                 for adv_name, adversary in self.adv_losses.items():
+                     disc_loss = adversary.train_adv(y_pred, y)
+                     d_losses[f'd_{adv_name}'] = disc_loss
+                 metrics['d_loss'] = torch.sum(torch.stack(list(d_losses.values())))
+             metrics.update(d_losses)
+
+         balanced_losses: dict = {}
+         other_losses: dict = {}
+
+         # penalty from quantization
+         if qres.penalty is not None and qres.penalty.requires_grad:
+             other_losses['penalty'] = qres.penalty  # penalty term from the quantizer
+
+         # adversarial losses
+         for adv_name, adversary in self.adv_losses.items():
+             adv_loss, feat_loss = adversary(y_pred, y)
+             balanced_losses[f'adv_{adv_name}'] = adv_loss
+             balanced_losses[f'feat_{adv_name}'] = feat_loss
+
+         # auxiliary losses
+         for loss_name, criterion in self.aux_losses.items():
+             loss = criterion(y_pred, y)
+             balanced_losses[loss_name] = loss
+
+         # weighted losses
+         metrics.update(balanced_losses)
+         metrics.update(other_losses)
+         metrics.update(qres.metrics)
+
+         if self.is_training:
+             # backprop losses that are not handled by balancer
+             other_loss = torch.tensor(0., device=self.device)
+             if 'penalty' in other_losses:
+                 other_loss += other_losses['penalty']
+             if other_loss.requires_grad:
+                 other_loss.backward(retain_graph=True)
+                 ratio1 = sum(p.grad.data.norm(p=2).pow(2)
+                              for p in self.model.parameters() if p.grad is not None)
+                 assert isinstance(ratio1, torch.Tensor)
+                 metrics['ratio1'] = ratio1.sqrt()
+
+             # balancer losses backward, returns effective training loss
+             # with effective weights at the current batch.
+             metrics['g_loss'] = self.balancer.backward(balanced_losses, y_pred)
+             # add metrics corresponding to weight ratios
+             metrics.update(self.balancer.metrics)
+             ratio2 = sum(p.grad.data.norm(p=2).pow(2)
+                          for p in self.model.parameters() if p.grad is not None)
+             assert isinstance(ratio2, torch.Tensor)
+             metrics['ratio2'] = ratio2.sqrt()
+
+             # optim
+             flashy.distrib.sync_model(self.model)
+             if self.cfg.optim.max_norm:
+                 torch.nn.utils.clip_grad_norm_(
+                     self.model.parameters(), self.cfg.optim.max_norm
+                 )
+             self.optimizer.step()
+             self.optimizer.zero_grad()
+
+         # informative losses only
+         info_losses: dict = {}
+         with torch.no_grad():
+             for loss_name, criterion in self.info_losses.items():
+                 loss = criterion(y_pred, y)
+                 info_losses[loss_name] = loss
+
+         metrics.update(info_losses)
+
+         # aggregated GAN losses: this is useful to report adv and feat across different adversarial loss setups
+         adv_losses = [loss for loss_name, loss in metrics.items() if loss_name.startswith('adv')]
+         if len(adv_losses) > 0:
+             metrics['adv'] = torch.sum(torch.stack(adv_losses))
+         feat_losses = [loss for loss_name, loss in metrics.items() if loss_name.startswith('feat')]
+         if len(feat_losses) > 0:
+             metrics['feat'] = torch.sum(torch.stack(feat_losses))
+
+         return metrics
+
176
+ def run_epoch(self):
177
+ # reset random seed at the beginning of the epoch
178
+ self.rng = torch.Generator()
179
+ self.rng.manual_seed(1234 + self.epoch)
180
+ # run epoch
181
+ super().run_epoch()
182
+
183
+ def evaluate(self):
184
+ """Evaluate stage. Runs audio reconstruction evaluation."""
185
+ self.model.eval()
186
+ evaluate_stage_name = str(self.current_stage)
187
+
188
+ loader = self.dataloaders['evaluate']
189
+ updates = len(loader)
190
+ lp = self.log_progress(f'{evaluate_stage_name} inference', loader, total=updates, updates=self.log_updates)
191
+ average = flashy.averager()
192
+
193
+ pendings = []
194
+ ctx = multiprocessing.get_context('spawn')
195
+ with get_pool_executor(self.cfg.evaluate.num_workers, mp_context=ctx) as pool:
196
+ for idx, batch in enumerate(lp):
197
+ x = batch.to(self.device)
198
+ with torch.no_grad():
199
+ qres = self.model(x)
200
+
201
+ y_pred = qres.x.cpu()
202
+ y = batch.cpu() # should already be on CPU but just in case
203
+ pendings.append(pool.submit(evaluate_audio_reconstruction, y_pred, y, self.cfg))
204
+
205
+ metrics_lp = self.log_progress(f'{evaluate_stage_name} metrics', pendings, updates=self.log_updates)
206
+ for pending in metrics_lp:
207
+ metrics = pending.result()
208
+ metrics = average(metrics)
209
+
210
+ metrics = flashy.distrib.average_metrics(metrics, len(loader))
211
+ return metrics
212
+
213
+ def generate(self):
214
+ """Generate stage."""
215
+ self.model.eval()
216
+ sample_manager = SampleManager(self.xp, map_reference_to_sample_id=True)
217
+ generate_stage_name = str(self.current_stage)
218
+
219
+ loader = self.dataloaders['generate']
220
+ updates = len(loader)
221
+ lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates)
222
+
223
+ for batch in lp:
224
+ reference, _ = batch
225
+ reference = reference.to(self.device)
226
+ with torch.no_grad():
227
+ qres = self.model(reference)
228
+ assert isinstance(qres, quantization.QuantizedResult)
229
+
230
+ reference = reference.cpu()
231
+ estimate = qres.x.cpu()
232
+ sample_manager.add_samples(estimate, self.epoch, ground_truth_wavs=reference)
233
+
234
+ flashy.distrib.barrier()
235
+
236
+ def load_from_pretrained(self, name: str) -> dict:
237
+ model = models.CompressionModel.get_pretrained(name)
238
+ if isinstance(model, models.DAC):
239
+ raise RuntimeError("Cannot fine tune a DAC model.")
240
+ elif isinstance(model, models.HFEncodecCompressionModel):
241
+ self.logger.warning('Trying to automatically convert a HuggingFace model '
242
+ 'to AudioCraft, this might fail!')
243
+ state = model.model.state_dict()
244
+ new_state = {}
245
+ for k, v in state.items():
246
+ if k.startswith('decoder.layers') and '.conv.' in k and '.block.' not in k:
247
+ # We need to determine if this a convtr or a regular conv.
248
+ layer = int(k.split('.')[2])
249
+ if isinstance(model.model.decoder.layers[layer].conv, torch.nn.ConvTranspose1d):
250
+ k = k.replace('.conv.', '.convtr.')
252
+ k = k.replace('encoder.layers.', 'encoder.model.')
253
+ k = k.replace('decoder.layers.', 'decoder.model.')
254
+ k = k.replace('conv.', 'conv.conv.')
255
+ k = k.replace('convtr.', 'convtr.convtr.')
256
+ k = k.replace('quantizer.layers.', 'quantizer.vq.layers.')
257
+ k = k.replace('.codebook.', '._codebook.')
258
+ new_state[k] = v
259
+ state = new_state
260
+ elif isinstance(model, models.EncodecModel):
261
+ state = model.state_dict()
262
+ else:
263
+ raise RuntimeError(f"Cannot fine tune model type {type(model)}.")
264
+ return {
265
+ 'best_state': {'model': state}
266
+ }
267
+
268
+ @staticmethod
269
+ def model_from_checkpoint(checkpoint_path: tp.Union[Path, str],
270
+ device: tp.Union[torch.device, str] = 'cpu') -> models.CompressionModel:
271
+ """Instantiate a CompressionModel from a given checkpoint path or dora sig.
272
+ This method is a convenient endpoint to load a CompressionModel to use in other solvers.
273
+
274
+ Args:
275
+ checkpoint_path (Path or str): Path to checkpoint or dora sig from where the checkpoint is resolved.
276
+ This also supports pre-trained models by using a path of the form //pretrained/NAME.
277
+ See `model_from_pretrained` for a list of supported pretrained models.
278
+ device (torch.device or str): Device on which the model is loaded.
280
+ """
281
+ checkpoint_path = str(checkpoint_path)
282
+ if checkpoint_path.startswith('//pretrained/'):
283
+ name = checkpoint_path.split('/', 3)[-1]
284
+ return models.CompressionModel.get_pretrained(name, device)
285
+ logger = logging.getLogger(__name__)
286
+ logger.info(f"Loading compression model from checkpoint: {checkpoint_path}")
287
+ _checkpoint_path = checkpoint.resolve_checkpoint_path(checkpoint_path, use_fsdp=False)
288
+ assert _checkpoint_path is not None, f"Could not resolve compression model checkpoint path: {checkpoint_path}"
289
+ state = checkpoint.load_checkpoint(_checkpoint_path)
290
+ assert state is not None and 'xp.cfg' in state, f"Could not load compression model from ckpt: {checkpoint_path}"
291
+ cfg = state['xp.cfg']
292
+ cfg.device = device
293
+ compression_model = models.builders.get_compression_model(cfg).to(device)
294
+ assert compression_model.sample_rate == cfg.sample_rate, "Compression model sample rate should match"
295
+
296
+ assert 'best_state' in state and state['best_state'] != {}
297
+ assert 'exported' not in state, "When loading an exported checkpoint, use the //pretrained/ prefix."
298
+ compression_model.load_state_dict(state['best_state']['model'])
299
+ compression_model.eval()
300
+ logger.info("Compression model loaded!")
301
+ return compression_model
302
+
303
+ @staticmethod
304
+ def wrapped_model_from_checkpoint(cfg: omegaconf.DictConfig,
305
+ checkpoint_path: tp.Union[Path, str],
306
+ device: tp.Union[torch.device, str] = 'cpu') -> models.CompressionModel:
307
+ """Instantiate a wrapped CompressionModel from a given checkpoint path or dora sig.
308
+
309
+ Args:
310
+ cfg (omegaconf.DictConfig): Configuration to read from for wrapped mode.
311
+ checkpoint_path (Path or str): Path to checkpoint or dora sig from where the checkpoint is resolved.
312
+ device (torch.device or str): Device on which the model is loaded.
314
+ """
315
+ compression_model = CompressionSolver.model_from_checkpoint(checkpoint_path, device)
316
+ compression_model = models.builders.get_wrapped_compression_model(compression_model, cfg)
317
+ return compression_model
318
+
319
+
320
+ def evaluate_audio_reconstruction(y_pred: torch.Tensor, y: torch.Tensor, cfg: omegaconf.DictConfig) -> dict:
321
+ """Audio reconstruction evaluation method that can be conveniently pickled."""
322
+ metrics = {}
323
+ if cfg.evaluate.metrics.visqol:
324
+ visqol = builders.get_visqol(cfg.metrics.visqol)
325
+ metrics['visqol'] = visqol(y_pred, y, cfg.sample_rate)
326
+ sisnr = builders.get_loss('sisnr', cfg)
327
+ metrics['sisnr'] = sisnr(y_pred, y)
328
+ return metrics
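
A minimal sketch of how these endpoints might be used from user code, assuming the audiocraft package is installed; the pretrained name below is only an example:

    import torch
    from audiocraft.solvers.compression import CompressionSolver

    # '//pretrained/...' paths are resolved through CompressionModel.get_pretrained.
    model = CompressionSolver.model_from_checkpoint('//pretrained/facebook/encodec_32khz')
    wav = torch.randn(1, model.channels, model.sample_rate)  # one second of noise
    with torch.no_grad():
        codes, scale = model.encode(wav)
        decoded = model.decode(codes, scale)
    print(codes.shape, decoded.shape)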
audiocraft/solvers/diffusion.py ADDED
@@ -0,0 +1,279 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import typing as tp
8
+
9
+ import flashy
10
+ import julius
11
+ import omegaconf
12
+ import torch
13
+ import torch.nn.functional as F
14
+
15
+ from . import builders
16
+ from . import base
17
+ from .. import models
18
+ from ..modules.diffusion_schedule import NoiseSchedule
19
+ from ..metrics import RelativeVolumeMel
20
+ from ..models.builders import get_processor
21
+ from ..utils.samples.manager import SampleManager
22
+ from ..solvers.compression import CompressionSolver
23
+
24
+
25
+ class PerStageMetrics:
26
+ """Handle computing the metrics per stage.
27
+ It outputs the metrics per range of diffusion steps.
28
+ e.g. avg loss when t in [250, 500]
29
+ """
30
+ def __init__(self, num_steps: int, num_stages: int = 4):
31
+ self.num_steps = num_steps
32
+ self.num_stages = num_stages
33
+
34
+ def __call__(self, losses: dict, step: tp.Union[int, torch.Tensor]):
35
+ if type(step) is int:
36
+ stage = int((step / self.num_steps) * self.num_stages)
37
+ return {f"{name}_{stage}": loss for name, loss in losses.items()}
38
+ elif type(step) is torch.Tensor:
39
+ stage_tensor = ((step / self.num_steps) * self.num_stages).long()
40
+ out: tp.Dict[str, float] = {}
41
+ for stage_idx in range(self.num_stages):
42
+ mask = (stage_tensor == stage_idx)
43
+ N = mask.sum()
44
+ stage_out = {}
45
+ if N > 0:  # skip stages with no elements
46
+ for name, loss in losses.items():
47
+ stage_loss = (mask * loss).sum() / N
48
+ stage_out[f"{name}_{stage_idx}"] = stage_loss
49
+ out = {**out, **stage_out}
50
+ return out
51
+
52
+
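
A self-contained sketch of how PerStageMetrics buckets per-example losses by diffusion step (the values below are dummies, not from a real run):

    import torch

    per_stage = PerStageMetrics(num_steps=1000, num_stages=4)
    losses = {'loss': torch.tensor([0.9, 0.5, 0.2, 0.1])}
    steps = torch.tensor([100, 400, 600, 900])  # one diffusion step per example
    print(per_stage(losses, steps))
    # -> {'loss_0': tensor(0.9000), 'loss_1': tensor(0.5000),
    #     'loss_2': tensor(0.2000), 'loss_3': tensor(0.1000)}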
53
+ class DataProcess:
54
+ """Apply filtering or resampling.
55
+
56
+ Args:
57
+ initial_sr (int): Initial sample rate.
58
+ target_sr (int): Target sample rate.
59
+ use_resampling (bool): Whether to resample from initial_sr to target_sr.
60
+ use_filter (bool): Whether to keep only one frequency band of the signal.
61
+ n_bands (int): Number of bands to split the signal into when filtering.
62
+ idx_band (int): Index of the frequency band to keep (0 is the lowest).
63
+ device (torch.device or str): Device on which the band filter runs.
64
+ cutoffs (list, optional): Cutoff frequencies of the band filter; mel-scale bands if None.
65
+ boost (bool): Whether to rescale the data to match the music dataset scale.
66
+ """
67
+ def __init__(self, initial_sr: int = 24000, target_sr: int = 16000, use_resampling: bool = False,
68
+ use_filter: bool = False, n_bands: int = 4,
69
+ idx_band: int = 0, device: torch.device = torch.device('cpu'), cutoffs=None, boost=False):
70
+ """Apply filtering or resampling
71
+ Args:
72
+ initial_sr (int): sample rate of the dataset
73
+ target_sr (int): sample rate after resampling
74
+ use_resampling (bool): whether or not performs resampling
75
+ use_filter (bool): when True filter the data to keep only one frequency band
76
+ n_bands (int): Number of bands used
77
+ cutoffs (None or list): The cutoff frequencies of the band filtering
78
+ if None then we use mel scale bands.
79
+ idx_band (int): index of the frequency band. 0 are lows ... (n_bands - 1) highs
80
+ boost (bool): make the data scale match our music dataset.
81
+ """
82
+ assert idx_band < n_bands
83
+ self.idx_band = idx_band
84
+ if use_filter:
85
+ if cutoffs is not None:
86
+ self.filter = julius.SplitBands(sample_rate=initial_sr, cutoffs=cutoffs).to(device)
87
+ else:
88
+ self.filter = julius.SplitBands(sample_rate=initial_sr, n_bands=n_bands).to(device)
89
+ self.use_filter = use_filter
90
+ self.use_resampling = use_resampling
91
+ self.target_sr = target_sr
92
+ self.initial_sr = initial_sr
93
+ self.boost = boost
94
+
95
+ def process_data(self, x, metric=False):
96
+ if x is None:
97
+ return None
98
+ if self.boost:
99
+ x /= torch.clamp(x.std(dim=(1, 2), keepdim=True), min=1e-4)
100
+ x *= 0.22  # rescale to match the music dataset
101
+ if self.use_filter and not metric:
102
+ x = self.filter(x)[self.idx_band]
103
+ if self.use_resampling:
104
+ x = julius.resample_frac(x, old_sr=self.initial_sr, new_sr=self.target_sr)
105
+ return x
106
+
107
+ def inverse_process(self, x):
108
+ """Upsampling only."""
109
+ if self.use_resampling:
110
+ x = julius.resample_frac(x, old_sr=self.target_sr, new_sr=self.initial_sr)
111
+ return x
112
+
113
+
114
+ class DiffusionSolver(base.StandardSolver):
115
+ """Solver for the diffusion task.
116
+
117
+ The diffusion task allows for MultiBand diffusion model training.
118
+
119
+ Args:
120
+ cfg (DictConfig): Configuration.
121
+ """
122
+ def __init__(self, cfg: omegaconf.DictConfig):
123
+ super().__init__(cfg)
124
+ self.cfg = cfg
125
+ self.device = cfg.device
126
+ self.sample_rate: int = self.cfg.sample_rate
127
+ self.codec_model = CompressionSolver.model_from_checkpoint(
128
+ cfg.compression_model_checkpoint, device=self.device)
129
+
130
+ self.codec_model.set_num_codebooks(cfg.n_q)
131
+ assert self.codec_model.sample_rate == self.cfg.sample_rate, (
132
+ f"Codec model sample rate is {self.codec_model.sample_rate} but "
133
+ f"Solver sample rate is {self.cfg.sample_rate}."
134
+ )
135
+
139
+ self.sample_processor = get_processor(cfg.processor, sample_rate=self.sample_rate)
140
+ self.register_stateful('sample_processor')
141
+ self.sample_processor.to(self.device)
142
+
143
+ self.schedule = NoiseSchedule(
144
+ **cfg.schedule, device=self.device, sample_processor=self.sample_processor)
145
+
146
+ self.eval_metric: tp.Optional[torch.nn.Module] = None
147
+
148
+ self.rvm = RelativeVolumeMel()
149
+ self.data_processor = DataProcess(initial_sr=self.sample_rate, target_sr=cfg.resampling.target_sr,
150
+ use_resampling=cfg.resampling.use, cutoffs=cfg.filter.cutoffs,
151
+ use_filter=cfg.filter.use, n_bands=cfg.filter.n_bands,
152
+ idx_band=cfg.filter.idx_band, device=self.device)
153
+
154
+ @property
155
+ def best_metric_name(self) -> tp.Optional[str]:
156
+ if self._current_stage == "evaluate":
157
+ return 'rvm'
158
+ else:
159
+ return 'loss'
160
+
161
+ @torch.no_grad()
162
+ def get_condition(self, wav: torch.Tensor) -> torch.Tensor:
163
+ codes, scale = self.codec_model.encode(wav)
164
+ assert scale is None, "Scaled compression models not supported."
165
+ emb = self.codec_model.decode_latent(codes)
166
+ return emb
167
+
168
+ def build_model(self):
169
+ """Build model and optimizer as well as optional Exponential Moving Average of the model.
170
+ """
171
+ # Model and optimizer
172
+ self.model = models.builders.get_diffusion_model(self.cfg).to(self.device)
173
+ self.optimizer = builders.get_optimizer(self.model.parameters(), self.cfg.optim)
174
+ self.register_stateful('model', 'optimizer')
175
+ self.register_best_state('model')
176
+ self.register_ema('model')
177
+
178
+ def build_dataloaders(self):
179
+ """Build audio dataloaders for each stage."""
180
+ self.dataloaders = builders.get_audio_datasets(self.cfg)
181
+
182
+ def show(self):
183
+ # TODO
184
+ raise NotImplementedError()
185
+
186
+ def run_step(self, idx: int, batch: torch.Tensor, metrics: dict):
187
+ """Perform one training or valid step on a given batch."""
188
+ x = batch.to(self.device)
189
+ loss_fun = F.mse_loss if self.cfg.loss.kind == 'mse' else F.l1_loss
190
+
191
+ condition = self.get_condition(x) # [bs, 128, T/hop, n_emb]
192
+ sample = self.data_processor.process_data(x)
193
+
194
+ input_, target, step = self.schedule.get_training_item(sample,
195
+ tensor_step=self.cfg.schedule.variable_step_batch)
196
+ out = self.model(input_, step, condition=condition).sample
197
+
198
+ base_loss = loss_fun(out, target, reduction='none').mean(dim=(1, 2))
199
+ reference_loss = loss_fun(input_, target, reduction='none').mean(dim=(1, 2))
200
+ loss = base_loss / reference_loss ** self.cfg.loss.norm_power
201
+
202
+ if self.is_training:
203
+ loss.mean().backward()
204
+ flashy.distrib.sync_model(self.model)
205
+ self.optimizer.step()
206
+ self.optimizer.zero_grad()
207
+ metrics = {
208
+ 'loss': loss.mean(), 'normed_loss': (base_loss / reference_loss).mean(),
209
+ }
210
+ metrics.update(self.per_stage({'loss': loss, 'normed_loss': base_loss / reference_loss}, step))
211
+ metrics.update({
212
+ 'std_in': input_.std(), 'std_out': out.std()})
213
+ return metrics
214
+
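
To make the loss normalization above concrete: each example's loss is divided by the loss the raw noisy input would get, raised to cfg.loss.norm_power. A toy sketch with random tensors (norm_power chosen for illustration):

    import torch
    import torch.nn.functional as F

    out = torch.randn(2, 1, 100)                    # model prediction
    target = torch.randn(2, 1, 100)                 # diffusion target
    input_ = target + 0.1 * torch.randn(2, 1, 100)  # noisy input

    base_loss = F.mse_loss(out, target, reduction='none').mean(dim=(1, 2))
    reference_loss = F.mse_loss(input_, target, reduction='none').mean(dim=(1, 2))
    norm_power = 1.0
    loss = base_loss / reference_loss ** norm_power  # per-example, shape [2]
    print(loss.mean())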
215
+ def run_epoch(self):
216
+ # reset random seed at the beginning of the epoch
217
+ self.rng = torch.Generator()
218
+ self.rng.manual_seed(1234 + self.epoch)
219
+ self.per_stage = PerStageMetrics(self.schedule.num_steps, self.cfg.metrics.num_stage)
220
+ # run epoch
221
+ super().run_epoch()
222
+
223
+ def evaluate(self):
224
+ """Evaluate stage.
225
+ Runs audio reconstruction evaluation.
226
+ """
227
+ self.model.eval()
228
+ evaluate_stage_name = f'{self.current_stage}'
229
+ loader = self.dataloaders['evaluate']
230
+ updates = len(loader)
231
+ lp = self.log_progress(f'{evaluate_stage_name} estimate', loader, total=updates, updates=self.log_updates)
232
+
233
+ metrics = {}
234
+ n = 1
235
+ for idx, batch in enumerate(lp):
236
+ x = batch.to(self.device)
237
+ with torch.no_grad():
238
+ y_pred = self.regenerate(x)
239
+
240
+ y_pred = y_pred.cpu()
241
+ y = batch.cpu() # should already be on CPU but just in case
242
+ rvm = self.rvm(y_pred, y)
243
+ lp.update(**rvm)
244
+ if len(metrics) == 0:
245
+ metrics = rvm
246
+ else:
247
+ for key in rvm.keys():
248
+ metrics[key] = (metrics[key] * n + rvm[key]) / (n + 1)
+ n += 1  # count accumulated batches so this stays a true running mean
249
+ metrics = flashy.distrib.average_metrics(metrics)
250
+ return metrics
251
+
252
+ @torch.no_grad()
253
+ def regenerate(self, wav: torch.Tensor, step_list: tp.Optional[list] = None):
254
+ """Regenerate the given waveform."""
255
+ condition = self.get_condition(wav)
256
+ initial = self.schedule.get_initial_noise(self.data_processor.process_data(wav)) # sampling rate changes.
257
+ result = self.schedule.generate_subsampled(self.model, initial=initial, condition=condition,
258
+ step_list=step_list)
259
+ result = self.data_processor.inverse_process(result)
260
+ return result
261
+
262
+ def generate(self):
263
+ """Generate stage."""
264
+ sample_manager = SampleManager(self.xp)
265
+ self.model.eval()
266
+ generate_stage_name = f'{self.current_stage}'
267
+
268
+ loader = self.dataloaders['generate']
269
+ updates = len(loader)
270
+ lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates)
271
+
272
+ for batch in lp:
273
+ reference, _ = batch
274
+ reference = reference.to(self.device)
275
+ estimate = self.regenerate(reference)
276
+ reference = reference.cpu()
277
+ estimate = estimate.cpu()
278
+ sample_manager.add_samples(estimate, self.epoch, ground_truth_wavs=reference)
279
+ flashy.distrib.barrier()
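
The evaluation loop above reduces to calling the relative volume mel metric on pairs of waveforms; a minimal sketch with random tensors standing in for real audio:

    import torch
    from audiocraft.metrics import RelativeVolumeMel

    rvm = RelativeVolumeMel()
    y_pred = torch.randn(1, 1, 24000)  # one second of fake audio
    y = torch.randn(1, 1, 24000)
    print(rvm(y_pred, y))  # dict of RVM values; exact keys depend on the metric config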
audiocraft/solvers/magnet.py ADDED
@@ -0,0 +1,276 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from omegaconf import DictConfig
8
+ from . import builders, musicgen
9
+ from einops import rearrange
10
+ from torch.nn import functional as F
11
+ from ..modules.conditioners import SegmentWithAttributes
12
+
13
+ import torch
14
+ import numpy as np
15
+ import random
16
+ import typing as tp
17
+ import math
18
+ import flashy
19
+
20
+
21
+ class MagnetSolver(musicgen.MusicGenSolver):
22
+ """Solver for MAGNeT - Masked Audio Generation using
23
+ a single Non-autoregressive Transformer https://arxiv.org/abs/2401.04577.
24
+ """
25
+ def __init__(self, cfg: DictConfig):
26
+ super().__init__(cfg)
27
+
28
+ # initialize generation parameters by config
29
+ self.generation_params = {
30
+ 'use_sampling': self.cfg.generate.lm.use_sampling,
31
+ 'temp': self.cfg.generate.lm.temp,
32
+ 'top_k': self.cfg.generate.lm.top_k,
33
+ 'top_p': self.cfg.generate.lm.top_p,
34
+ 'max_cfg_coef': self.cfg.generate.lm.max_cfg_coef,
35
+ 'min_cfg_coef': self.cfg.generate.lm.min_cfg_coef,
36
+ 'decoding_steps': list(self.cfg.generate.lm.decoding_steps),
37
+ 'anneal_temp': self.cfg.generate.lm.anneal_temp,
38
+ 'span_scoring': self.cfg.generate.lm.span_scoring,
39
+ 'span_arrangement': self.cfg.generate.lm.span_arrangement
40
+ }
41
+
42
+ sequence_len = int(cfg.dataset.segment_duration * self.compression_model.frame_rate)
43
+ self.mean_maskrate_to_u = torch.tensor(self._calc_mean_maskrate_to_u_LUT(sequence_len), device=self.device)
44
+ self.ce_per_codebook = [torch.log(torch.tensor(self.compression_model.cardinality, device=self.device))
45
+ for _ in range(cfg.transformer_lm.n_q)]
46
+
47
+ def build_model(self) -> None:
48
+ self.cfg.transformer_lm.segment_duration = self.cfg.dataset.segment_duration
49
+ self.cfg.transformer_lm.span_len = self.cfg.masking.span_len
50
+ assert self.cfg.efficient_attention_backend == "xformers", "MAGNeT v1 models support only xformers backend."
51
+ super().build_model()
52
+
53
+ def _calc_mean_maskrate_to_u_LUT(self, T: int):
54
+ """ Create a Look Up Table (LUT) transforming a discrete masking percentage m in 0,1,...,100 to u,
55
+ the number of overlapping spans of length L to place s.t. the masking rate is approximately m/float(100).
56
+ It first creates the inverse transformation, of the masking rate as function of u,
57
+ using the expression choose(T - L, u) / choose(T, u), where L is the atomic span length used
58
+ during masking. See https://arxiv.org/abs/2401.04577,
59
+ appendix C, for the mean mask rate derivation.
60
+
61
+ We leverage the fact that:
62
+ choose(T - L, u) / choose(T, u) = Prod_{j = 0}^{u - 1}((T - L - j)/(T - j))
63
+ in the provided implementation, in order to avoid overflow.
64
+ Args:
65
+ T (int): Sequence length.
66
+ Returns:
67
+ (List) A LUT transforming m in 0,1,...,100 to u,
68
+ s.t. the masking rate of the span-L mask is approximately m/float(100).
69
+ """
70
+
71
+ L = self.cfg.masking.span_len
72
+
73
+ u2mean = [0.0] # mean mask rate is 0.0 for u = 0
74
+ v = (T - L) / float(T)
75
+ for u in range(1, T):
76
+ u2mean.append(1 - v)
77
+ v *= (T - L - u) / (T - u) # Overflow-safe implementation of choose(T - L, u) / choose(T, u).
78
+
79
+ mean2u = []
80
+ for maskperc in range(101):
81
+ maskrate = maskperc / float(100)
82
+ u = int(np.searchsorted(u2mean, maskrate))
83
+ mean2u.append(u)
84
+
85
+ return mean2u
86
+
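
The overflow-safe recurrence in the docstring can be checked against the binomial ratio directly; a standalone sketch with toy values:

    from math import comb

    T, L = 10, 3  # toy sequence length and span length
    u2mean = [0.0]
    v = (T - L) / float(T)
    for u in range(1, T):
        u2mean.append(1 - v)        # mean mask rate with u span starts
        v *= (T - L - u) / (T - u)  # incremental choose(T - L, u) / choose(T, u)

    assert abs(u2mean[2] - (1 - comb(T - L, 2) / comb(T, 2))) < 1e-9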
87
+ def _non_spans_mask(self, mask_probs: torch.Tensor, B: int, T: int, device: torch.device) -> torch.Tensor:
88
+ """ Construct a boolean mask of shape [B, T], with masking rates defined by mask_probs.
89
+ The masked tokens are singletons, placed uniformly at random.
90
+ Args:
91
+ mask_probs (torch.Tensor): The desired masking rate per sample, of shape [B,]
92
+ B (int): Batch size.
93
+ T (int): Sequence length.
94
+ device (torch.device): device of the output tensor
95
+ Returns:
96
+ (torch.Tensor): A mask of shape [B, T]
97
+ """
98
+ num_token_masked = (T * mask_probs).round().clamp(min=1)
99
+ batch_randperm = torch.rand((B, T), device=device).argsort(dim=-1)
100
+ return batch_randperm < rearrange(num_token_masked, 'b -> b 1')
101
+
102
+ def _spans_mask(self, mask_probs: torch.Tensor, B: int, T: int, device: torch.device) -> torch.Tensor:
103
+ """ Construct a spans mask with masking rates defined by mask_probs,
104
+ where the atomic span length ( > 1 ) is defined by cfg.masking.span_len.
105
+ Args:
106
+ mask_probs (torch.Tensor): The desired masking rate per sample, of shape [B,]
107
+ B (int): Batch size.
108
+ T (int): Sequence length.
109
+ device (torch.device): device of the output tensor
110
+ Returns:
111
+ (torch.Tensor): A spans mask of shape [B, T]
112
+ """
113
+ rounded_probs = torch.round(100 * mask_probs).long()
114
+ k = self.mean_maskrate_to_u[rounded_probs].clamp(min=1) # k is the number of span starts
115
+
116
+ # sample random span starts
117
+ batch_randperm = torch.rand((B, T), device=device).argsort(dim=-1)
118
+ mask = batch_randperm < rearrange(k, 'b -> b 1')
119
+ B, T = mask.shape
120
+ shifted_mask = mask.clone()
121
+ for _ in range(self.cfg.masking.span_len - 1):
122
+ shifted_mask = torch.concat((torch.full((B, 1), False, device=device), shifted_mask[:, :-1]), dim=1)
123
+ mask = torch.logical_or(mask, shifted_mask)
124
+
125
+ return mask
126
+
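
The shifted-OR trick above grows each sampled start into a full span; on a tiny example (one start at position 2, span length 3):

    import torch

    B, T, span_len = 1, 8, 3
    mask = torch.zeros(B, T, dtype=torch.bool)
    mask[0, 2] = True  # one sampled span start
    shifted = mask.clone()
    for _ in range(span_len - 1):
        shifted = torch.cat((torch.zeros(B, 1, dtype=torch.bool), shifted[:, :-1]), dim=1)
        mask = mask | shifted
    print(mask.int())  # tensor([[0, 0, 1, 1, 1, 0, 0, 0]])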
127
+ def _get_mask(self, mask_probs: torch.Tensor, B: int, T: int, device: torch.device) -> torch.Tensor:
128
+ """ Construct a boolean mask with masking rates defined by mask_probs, and atomic
129
+ span length defined by cfg.masking.span_len.
130
+ Args:
131
+ mask_probs (torch.Tensor): The desired masking rate per sample, of shape [B,]
132
+ B (int): Batch size.
133
+ T (int): Sequence length.
134
+ device (torch.device): device of the output tensor
135
+ Returns:
136
+ (torch.Tensor): A boolean tensor of shape [B, T]
137
+ """
138
+ if self.cfg.masking.span_len <= 1:
139
+ return self._non_spans_mask(mask_probs, B, T, device)
140
+
141
+ return self._spans_mask(mask_probs, B, T, device)
142
+
143
+ def _compute_cross_entropy_magnet(self, logits: torch.Tensor,
144
+ targets: torch.Tensor, mask: torch.Tensor, stage: torch.Tensor) -> torch.Tensor:
145
+ """ Compute cross entropy between multi-codebook targets and model's logits.
146
+ The cross entropy is computed only on a specific codebook, defined by the stage argument.
147
+ Valid timesteps for each codebook are pulled from the mask, where invalid
148
+ timesteps are set to 0.
149
+
150
+ Args:
151
+ logits (torch.Tensor): Model's logits of shape [B, K, T, card].
152
+ targets (torch.Tensor): Target codes, of shape [B, K, T].
153
+ mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T].
154
+ stage (torch.Tensor): The codebook (idx) that is being optimized, as a scalar tensor.
155
+ Returns:
156
+ ce (torch.Tensor): Cross entropy of the codebook that is being optimized.
157
+ """
158
+ assert logits.shape[:-1] == targets.shape
159
+ assert mask.shape == targets.shape
160
+ ce = torch.zeros([], device=targets.device)
161
+ logits_k = logits[:, stage, ...].contiguous().view(-1, logits.size(-1)) # [B x T, card]
162
+ targets_k = targets[:, stage, ...].contiguous().view(-1) # [B x T]
163
+ mask_k = mask[:, stage, ...].contiguous().view(-1) # [B x T]
164
+
165
+ IGNORE_IDX = -1
166
+ targets_k[~mask_k] = IGNORE_IDX
167
+ q_ce = F.cross_entropy(logits_k, targets_k, ignore_index=IGNORE_IDX)
168
+
169
+ ce += q_ce
170
+ return ce
171
+
172
+ def run_step(self, idx: int, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]], metrics: dict) -> dict:
173
+ """Perform one training or valid step on a given batch."""
174
+ check_synchronization_points = idx == 1 and self.device == 'cuda'
175
+
176
+ condition_tensors, audio_tokens, padding_mask = self._prepare_tokens_and_attributes(
177
+ batch, check_synchronization_points)
178
+
179
+ self.deadlock_detect.update('tokens_and_conditions')
180
+
181
+ if check_synchronization_points:
182
+ torch.cuda.set_sync_debug_mode('warn')
183
+
184
+ B, K, T = audio_tokens.shape
185
+ device = self.device
186
+
187
+ # Choose the stage (codebook idx) for update, uniformly at random.
188
+ stage_ = random.randint(0, K - 1)
189
+ stage = torch.full((1, ), stage_, device=device)
190
+
191
+ # masking
192
+ rand_time = torch.zeros((B,), device=device).float().uniform_(0, 1)
193
+ rand_mask_probs = torch.cos(rand_time * math.pi * 0.5)
194
+
195
+ # stage mask
196
+ stage_mask = self._get_mask(rand_mask_probs, B, T, device) # [B, T]
197
+ stage_mask = stage_mask.unsqueeze(1) # [B, 1, T]
198
+
199
+ # Keep all preceding codebooks.
200
+ mask = torch.full((B, K, T), False, device=device)
201
+ mask[:, stage, :] = stage_mask
202
+
203
+ # Mask all codebooks larger than stage_
204
+ mask_id = self.model.special_token_id
205
+ mask[:, (stage_+1):, :] = torch.full((B, K - stage_ - 1, T), True, device=device)
206
+ input_tokens = torch.where(mask, mask_id, audio_tokens)
207
+
208
+ # Take loss only on the chosen stage, and only on the masked tokens.
209
+ loss_mask = torch.full((B, K, T), False, device=device)
210
+ loss_mask[:, stage, :] = stage_mask
211
+
212
+ with self.autocast:
213
+ model_output = self.model.compute_predictions(input_tokens, [], condition_tensors, stage=stage_)
214
+ logits = model_output.logits
215
+ loss_mask &= padding_mask
216
+ ce = self._compute_cross_entropy_magnet(logits, audio_tokens, loss_mask, stage)
217
+ loss = ce
218
+ self.deadlock_detect.update('loss')
219
+
220
+ if check_synchronization_points:
221
+ torch.cuda.set_sync_debug_mode('default')
222
+
223
+ if self.is_training:
224
+ metrics['lr'] = self.optimizer.param_groups[0]['lr']
225
+ if self.scaler is not None:
226
+ loss = self.scaler.scale(loss)
227
+ self.deadlock_detect.update('scale')
228
+ if self.cfg.fsdp.use:
229
+ loss.backward()
230
+ flashy.distrib.average_tensors(self.model.buffers())
231
+ elif self.cfg.optim.eager_sync:
232
+ with flashy.distrib.eager_sync_model(self.model):
233
+ loss.backward()
234
+ else:
235
+ # this should always be slower but can be useful
236
+ # for weird use cases like multiple backwards.
237
+ loss.backward()
238
+ flashy.distrib.sync_model(self.model)
239
+ self.deadlock_detect.update('backward')
240
+
241
+ if self.scaler is not None:
242
+ self.scaler.unscale_(self.optimizer)
243
+ if self.cfg.optim.max_norm:
244
+ if self.cfg.fsdp.use:
245
+ metrics['grad_norm'] = self.model.clip_grad_norm_(self.cfg.optim.max_norm) # type: ignore
246
+ else:
247
+ metrics['grad_norm'] = torch.nn.utils.clip_grad_norm_(
248
+ self.model.parameters(), self.cfg.optim.max_norm
249
+ )
250
+ if self.scaler is None:
251
+ self.optimizer.step()
252
+ else:
253
+ self.scaler.step(self.optimizer)
254
+ self.scaler.update()
255
+ if self.lr_scheduler:
256
+ self.lr_scheduler.step()
257
+ self.optimizer.zero_grad()
258
+ self.deadlock_detect.update('optim')
259
+ if self.scaler is not None:
260
+ scale = self.scaler.get_scale()
261
+ metrics['grad_scale'] = scale
262
+ if not loss.isfinite().all():
263
+ raise RuntimeError("Model probably diverged.")
264
+
265
+ metrics['ce'] = ce
266
+ metrics['ppl'] = torch.exp(ce)
267
+
268
+ return metrics
269
+
270
+
271
+ class AudioMagnetSolver(MagnetSolver):
272
+ """Solver for audio-MAGNeT. A MAGNeT model for sound generation.
273
+
274
+ More information can be found in the MAGNeT model card.
275
+ """
276
+ DATASET_TYPE: builders.DatasetType = builders.DatasetType.SOUND
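
For reference, the cosine schedule used in MagnetSolver.run_step maps a uniform time in [0, 1] to a masking rate cos(t * pi / 2), so early times mask nearly everything and late times nearly nothing; a quick sketch:

    import math
    import torch

    rand_time = torch.tensor([0.0, 0.25, 0.5, 0.75, 1.0])
    mask_probs = torch.cos(rand_time * math.pi * 0.5)
    print(mask_probs)  # approximately [1.00, 0.92, 0.71, 0.38, 0.00]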
audiocraft/solvers/musicgen.py ADDED
@@ -0,0 +1,721 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from pathlib import Path
8
+ import time
9
+ import typing as tp
10
+ import warnings
11
+
12
+ import flashy
13
+ import math
14
+ import omegaconf
15
+ import torch
16
+ from torch.nn import functional as F
17
+
18
+ from . import base, builders
19
+ from .compression import CompressionSolver
20
+ from .. import metrics as eval_metrics
21
+ from .. import models
22
+ from ..data.audio_dataset import AudioDataset
23
+ from ..data.music_dataset import MusicDataset, MusicInfo, AudioInfo
24
+ from ..data.audio_utils import normalize_audio
25
+ from ..modules.conditioners import JointEmbedCondition, SegmentWithAttributes, WavCondition
26
+ from ..utils.cache import CachedBatchWriter, CachedBatchLoader
27
+ from ..utils.samples.manager import SampleManager
28
+ from ..utils.utils import get_dataset_from_loader, is_jsonable, warn_once, model_hash
29
+
30
+
31
+ class MusicGenSolver(base.StandardSolver):
32
+ """Solver for MusicGen training task.
33
+
34
+ Used in: https://arxiv.org/abs/2306.05284
35
+ """
36
+ DATASET_TYPE: builders.DatasetType = builders.DatasetType.MUSIC
37
+
38
+ def __init__(self, cfg: omegaconf.DictConfig):
39
+ super().__init__(cfg)
40
+ # easier access to sampling parameters
41
+ self.generation_params = {
42
+ 'use_sampling': self.cfg.generate.lm.use_sampling,
43
+ 'temp': self.cfg.generate.lm.temp,
44
+ 'top_k': self.cfg.generate.lm.top_k,
45
+ 'top_p': self.cfg.generate.lm.top_p,
46
+ }
47
+ self._best_metric_name: tp.Optional[str] = 'ce'
48
+
49
+ self._cached_batch_writer = None
50
+ self._cached_batch_loader = None
51
+ if cfg.cache.path:
52
+ if cfg.cache.write:
53
+ self._cached_batch_writer = CachedBatchWriter(Path(cfg.cache.path))
54
+ if self.cfg.cache.write_num_shards:
55
+ self.logger.warning("Multiple shard cache, best_metric_name will be set to None.")
56
+ self._best_metric_name = None
57
+ else:
58
+ self._cached_batch_loader = CachedBatchLoader(
59
+ Path(cfg.cache.path), cfg.dataset.batch_size, cfg.dataset.num_workers,
60
+ min_length=self.cfg.optim.updates_per_epoch or 1)
61
+ self.dataloaders['original_train'] = self.dataloaders['train']
62
+ self.dataloaders['train'] = self._cached_batch_loader # type: ignore
63
+
64
+ @staticmethod
65
+ def get_eval_solver_from_sig(sig: str, dtype: tp.Optional[str] = None,
66
+ device: tp.Optional[str] = None, autocast: bool = True,
67
+ batch_size: tp.Optional[int] = None,
68
+ override_cfg: tp.Optional[tp.Union[dict, omegaconf.DictConfig]] = None,
69
+ **kwargs):
70
+ """Mostly a convenience function around audiocraft.train.get_solver_from_sig,
71
+ populating all the proper params, deactivating EMA and FSDP, loading the best state,
72
+ basically all you need to get a solver ready to "play" with in single GPU mode
73
+ and with minimal memory overhead.
74
+
75
+ Args:
76
+ sig (str): signature to load.
77
+ dtype (str or None): potential dtype, as a string, i.e. 'float16'.
78
+ device (str or None): potential device, as a string, i.e. 'cuda'.
79
+ override_cfg (dict or omegaconf.DictConfig or None): optional configuration overrides merged into the solver config.
80
+ """
81
+ from audiocraft import train
82
+ our_override_cfg: tp.Dict[str, tp.Any] = {'optim': {'ema': {'use': False}}}
83
+ our_override_cfg['autocast'] = autocast
84
+ if dtype is not None:
85
+ our_override_cfg['dtype'] = dtype
86
+ if device is not None:
87
+ our_override_cfg['device'] = device
88
+ if batch_size is not None:
89
+ our_override_cfg['dataset'] = {'batch_size': batch_size}
90
+ if override_cfg is None:
91
+ override_cfg = {}
92
+ override_cfg = omegaconf.OmegaConf.merge(
93
+ omegaconf.DictConfig(override_cfg), omegaconf.DictConfig(our_override_cfg)) # type: ignore
94
+ solver = train.get_solver_from_sig(
95
+ sig, override_cfg=override_cfg,
96
+ load_best=True, disable_fsdp=True,
97
+ ignore_state_keys=['optimizer', 'ema'], **kwargs)
98
+ solver.model.eval()
99
+ return solver
100
+
101
+ def get_formatter(self, stage_name: str) -> flashy.Formatter:
102
+ return flashy.Formatter({
103
+ 'lr': '.2E',
104
+ 'ce': '.3f',
105
+ 'ppl': '.3f',
106
+ 'grad_norm': '.3E',
107
+ }, exclude_keys=['ce_q*', 'ppl_q*'])
108
+
109
+ @property
110
+ def best_metric_name(self) -> tp.Optional[str]:
111
+ return self._best_metric_name
112
+
113
+ def build_model(self) -> None:
114
+ """Instantiate models and optimizer."""
115
+ # we can potentially not use all quantizers with which the EnCodec model was trained
116
+ # (e.g. we trained the model with quantizers dropout)
117
+ self.compression_model = CompressionSolver.wrapped_model_from_checkpoint(
118
+ self.cfg, self.cfg.compression_model_checkpoint, device=self.device)
119
+ assert self.compression_model.sample_rate == self.cfg.sample_rate, (
120
+ f"Compression model sample rate is {self.compression_model.sample_rate} but "
121
+ f"Solver sample rate is {self.cfg.sample_rate}."
122
+ )
123
+ # ensure we have matching configuration between LM and compression model
124
+ assert self.cfg.transformer_lm.card == self.compression_model.cardinality, (
125
+ "Cardinalities of the LM and compression model don't match: ",
126
+ f"LM cardinality is {self.cfg.transformer_lm.card} vs ",
127
+ f"compression model cardinality is {self.compression_model.cardinality}"
128
+ )
129
+ assert self.cfg.transformer_lm.n_q == self.compression_model.num_codebooks, (
130
+ "Numbers of codebooks of the LM and compression models don't match: ",
131
+ f"LM number of codebooks is {self.cfg.transformer_lm.n_q} vs ",
132
+ f"compression model number of codebooks is {self.compression_model.num_codebooks}"
133
+ )
134
+ self.logger.info("Compression model has %d codebooks with %d cardinality, and a framerate of %d",
135
+ self.compression_model.num_codebooks, self.compression_model.cardinality,
136
+ self.compression_model.frame_rate)
137
+ # instantiate LM model
138
+ self.model: models.LMModel = models.builders.get_lm_model(self.cfg).to(self.device)
139
+ if self.cfg.fsdp.use:
140
+ assert not self.cfg.autocast, "Cannot use autocast with fsdp"
141
+ self.model = self.wrap_with_fsdp(self.model)
142
+ self.register_ema('model')
143
+ # initialize optimization
144
+ self.optimizer = builders.get_optimizer(builders.get_optim_parameter_groups(self.model), self.cfg.optim)
145
+ self.lr_scheduler = builders.get_lr_scheduler(self.optimizer, self.cfg.schedule, self.total_updates)
146
+ self.register_stateful('model', 'optimizer', 'lr_scheduler')
147
+ self.register_best_state('model')
148
+ self.autocast_dtype = {
149
+ 'float16': torch.float16, 'bfloat16': torch.bfloat16
150
+ }[self.cfg.autocast_dtype]
151
+ self.scaler: tp.Optional[torch.cuda.amp.GradScaler] = None
152
+ if self.cfg.fsdp.use:
153
+ need_scaler = self.cfg.fsdp.param_dtype == 'float16'
154
+ else:
155
+ need_scaler = self.cfg.autocast and self.autocast_dtype is torch.float16
156
+ if need_scaler:
157
+ if self.cfg.fsdp.use:
158
+ from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
159
+ self.scaler = ShardedGradScaler() # type: ignore
160
+ else:
161
+ self.scaler = torch.cuda.amp.GradScaler()
162
+ self.register_stateful('scaler')
163
+
164
+ def build_dataloaders(self) -> None:
165
+ """Instantiate audio dataloaders for each stage."""
166
+ self.dataloaders = builders.get_audio_datasets(self.cfg, dataset_type=self.DATASET_TYPE)
167
+
168
+ def show(self) -> None:
169
+ """Show the compression model and LM model."""
170
+ self.logger.info("Compression model:")
171
+ self.log_model_summary(self.compression_model)
172
+ self.logger.info("LM model:")
173
+ self.log_model_summary(self.model)
174
+
175
+ def load_state_dict(self, state: dict) -> None:
176
+ if 'condition_provider' in state:
177
+ model_state = state['model']
178
+ condition_provider_state = state.pop('condition_provider')
179
+ prefix = 'condition_provider.'
180
+ for key, value in condition_provider_state.items():
181
+ key = prefix + key
182
+ assert key not in model_state
183
+ model_state[key] = value
184
+ if 'compression_model' in state:
185
+ # We used to store the `compression_model` state in the checkpoint, however
186
+ # this is in general not needed, as the compression model should always be readable
187
+ # from the original `cfg.compression_model_checkpoint` location.
188
+ compression_model_state = state.pop('compression_model')
189
+ before_hash = model_hash(self.compression_model)
190
+ self.compression_model.load_state_dict(compression_model_state)
191
+ after_hash = model_hash(self.compression_model)
192
+ if before_hash != after_hash:
193
+ raise RuntimeError(
194
+ "The compression model state inside the checkpoint is different"
195
+ " from the one obtained from compression_model_checkpoint. "
196
+ "We do not support altering the compression model inside the LM "
197
+ "checkpoint as parts of the code, in particular for running eval post-training "
198
+ "will use the compression_model_checkpoint as the source of truth.")
199
+
200
+ super().load_state_dict(state)
201
+
202
+ def load_from_pretrained(self, name: str):
203
+ # TODO: support native HF versions of MusicGen.
204
+ lm_pkg = models.loaders.load_lm_model_ckpt(name)
205
+ state: dict = {
206
+ 'best_state': {
207
+ 'model': lm_pkg['best_state'],
208
+ },
209
+ }
210
+ return state
211
+
212
+ def _compute_cross_entropy(
213
+ self, logits: torch.Tensor, targets: torch.Tensor, mask: torch.Tensor
214
+ ) -> tp.Tuple[torch.Tensor, tp.List[torch.Tensor]]:
215
+ """Compute cross entropy between multi-codebook targets and model's logits.
216
+ The cross entropy is computed per codebook to provide codebook-level cross entropy.
217
+ Valid timesteps for each of the codebook are pulled from the mask, where invalid
218
+ timesteps are set to 0.
219
+
220
+ Args:
221
+ logits (torch.Tensor): Model's logits of shape [B, K, T, card].
222
+ targets (torch.Tensor): Target codes, of shape [B, K, T].
223
+ mask (torch.Tensor): Mask for valid target codes, of shape [B, K, T].
224
+ Returns:
225
+ ce (torch.Tensor): Cross entropy averaged over the codebooks
226
+ ce_per_codebook (list of torch.Tensor): Cross entropy per codebook (detached).
227
+ """
228
+ B, K, T = targets.shape
229
+ assert logits.shape[:-1] == targets.shape
230
+ assert mask.shape == targets.shape
231
+ ce = torch.zeros([], device=targets.device)
232
+ ce_per_codebook: tp.List[torch.Tensor] = []
233
+ for k in range(K):
234
+ logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1)) # [B x T, card]
235
+ targets_k = targets[:, k, ...].contiguous().view(-1) # [B x T]
236
+ mask_k = mask[:, k, ...].contiguous().view(-1) # [B x T]
237
+ ce_targets = targets_k[mask_k]
238
+ ce_logits = logits_k[mask_k]
239
+ q_ce = F.cross_entropy(ce_logits, ce_targets)
240
+ ce += q_ce
241
+ ce_per_codebook.append(q_ce.detach())
242
+ # average cross entropy across codebooks
243
+ ce = ce / K
244
+ return ce, ce_per_codebook
245
+
246
+ def _prepare_tokens_and_attributes(
247
+ self, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]],
248
+ check_synchronization_points: bool = False
249
+ ) -> tp.Tuple[dict, torch.Tensor, torch.Tensor]:
250
+ """Prepare input batches for language model training.
251
+
252
+ Args:
253
+ batch (tuple[torch.Tensor, list[SegmentWithAttributes]]): Input batch with audio tensor of shape [B, C, T]
254
+ and corresponding metadata as SegmentWithAttributes (with B items).
255
+ check_synchronization_points (bool): Whether to check for synchronization points slowing down training.
256
+ Returns:
257
+ Condition tensors (dict[str, any]): Preprocessed condition attributes.
258
+ Tokens (torch.Tensor): Audio tokens from compression model, of shape [B, K, T_s],
259
+ with B the batch size, K the number of codebooks, T_s the token timesteps.
260
+ Padding mask (torch.Tensor): Mask with valid positions in the tokens tensor, of shape [B, K, T_s].
261
+ """
262
+ if self.model.training:
263
+ warnings.warn(
264
+ "Up to version 1.0.1, the _prepare_tokens_and_attributes was evaluated with `torch.no_grad()`. "
265
+ "This is inconsistent with how the models were trained in the MusicGen paper. We removed the "
266
+ "`torch.no_grad()` in version 1.1.0. Small changes to the final performance are expected. "
267
+ "Really sorry about that.")
268
+ if self._cached_batch_loader is None or self.current_stage != "train":
269
+ audio, infos = batch
270
+ audio = audio.to(self.device)
271
+ audio_tokens = None
272
+ assert audio.size(0) == len(infos), (
273
+ f"Mismatch between number of items in audio batch ({audio.size(0)})",
274
+ f" and in metadata ({len(infos)})"
275
+ )
276
+ else:
277
+ audio = None
278
+ # In that case the batch will be a tuple coming from the _cached_batch_writer bit below.
279
+ infos, = batch # type: ignore
280
+ assert all([isinstance(info, AudioInfo) for info in infos])
281
+ assert all([info.audio_tokens is not None for info in infos]) # type: ignore
282
+ audio_tokens = torch.stack([info.audio_tokens for info in infos]).to(self.device) # type: ignore
283
+ audio_tokens = audio_tokens.long()
284
+ for info in infos:
285
+ if isinstance(info, MusicInfo):
286
+ # Careful here, if you want to use this condition_wav (e.g. chroma conditioning),
287
+ # then you must be using the chroma cache! otherwise the code will try
288
+ # to use this segment and fail (by that I mean you will see NaN everywhere).
289
+ info.self_wav = WavCondition(
290
+ torch.full([1, info.channels, info.total_frames], float('NaN')),
291
+ length=torch.tensor([info.n_frames]),
292
+ sample_rate=[info.sample_rate],
293
+ path=[info.meta.path],
294
+ seek_time=[info.seek_time])
295
+ dataset = get_dataset_from_loader(self.dataloaders['original_train'])
296
+ assert isinstance(dataset, MusicDataset), type(dataset)
297
+ if dataset.paraphraser is not None and info.description is not None:
298
+ # Hackily reapply the paraphraser when using the cache.
299
+ info.description = dataset.paraphraser.sample_paraphrase(
300
+ info.meta.path, info.description)
301
+ # prepare attributes
302
+ attributes = [info.to_condition_attributes() for info in infos]
303
+ attributes = self.model.cfg_dropout(attributes)
304
+ attributes = self.model.att_dropout(attributes)
305
+ tokenized = self.model.condition_provider.tokenize(attributes)
306
+
307
+ # Now we should be synchronization free.
308
+ if self.device == "cuda" and check_synchronization_points:
309
+ torch.cuda.set_sync_debug_mode("warn")
310
+
311
+ if audio_tokens is None:
312
+ with torch.no_grad():
313
+ audio_tokens, scale = self.compression_model.encode(audio)
314
+ assert scale is None, "Scaled compression model not supported with LM."
315
+
316
+ with self.autocast:
317
+ condition_tensors = self.model.condition_provider(tokenized)
318
+
319
+ # create a padding mask to hold valid vs invalid positions
320
+ padding_mask = torch.ones_like(audio_tokens, dtype=torch.bool, device=audio_tokens.device)
321
+ # replace encodec tokens from padded audio with special_token_id
322
+ if self.cfg.tokens.padding_with_special_token:
323
+ audio_tokens = audio_tokens.clone()
324
+ padding_mask = padding_mask.clone()
325
+ token_sample_rate = self.compression_model.frame_rate
326
+ B, K, T_s = audio_tokens.shape
327
+ for i in range(B):
328
+ n_samples = infos[i].n_frames
329
+ audio_sample_rate = infos[i].sample_rate
330
+ # take the last token generated from actual audio frames (non-padded audio)
331
+ valid_tokens = math.floor(float(n_samples) / audio_sample_rate * token_sample_rate)
332
+ audio_tokens[i, :, valid_tokens:] = self.model.special_token_id
333
+ padding_mask[i, :, valid_tokens:] = 0
334
+
335
+ if self.device == "cuda" and check_synchronization_points:
336
+ torch.cuda.set_sync_debug_mode("default")
337
+
338
+ if self._cached_batch_writer is not None and self.current_stage == 'train':
339
+ assert self._cached_batch_loader is None
340
+ assert audio_tokens is not None
341
+ for info, one_audio_tokens in zip(infos, audio_tokens):
342
+ assert isinstance(info, AudioInfo)
343
+ if isinstance(info, MusicInfo):
344
+ assert not info.joint_embed, "joint_embed and cache not supported yet."
345
+ info.self_wav = None
346
+ assert one_audio_tokens.max() < 2**15, one_audio_tokens.max().item()
347
+ info.audio_tokens = one_audio_tokens.short().cpu()
348
+ self._cached_batch_writer.save(infos)
349
+
350
+ return condition_tensors, audio_tokens, padding_mask
351
+
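
The padding logic above converts audio frame counts into token counts; a worked example with round numbers (the rates below are hypothetical):

    import math

    n_samples = 48000          # 1.5 s of real (non-padded) audio in the segment
    audio_sample_rate = 32000
    token_sample_rate = 50     # compression model frame rate in Hz
    valid_tokens = math.floor(n_samples / audio_sample_rate * token_sample_rate)
    print(valid_tokens)  # 75; tokens past this index get special_token_id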
352
+ def run_step(self, idx: int, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]], metrics: dict) -> dict:
353
+ """Perform one training or valid step on a given batch."""
354
+ check_synchronization_points = idx == 1 and self.device == 'cuda'
355
+
356
+ condition_tensors, audio_tokens, padding_mask = self._prepare_tokens_and_attributes(
357
+ batch, check_synchronization_points)
358
+
359
+ self.deadlock_detect.update('tokens_and_conditions')
360
+
361
+ if check_synchronization_points:
362
+ torch.cuda.set_sync_debug_mode('warn')
363
+
364
+ with self.autocast:
365
+ model_output = self.model.compute_predictions(audio_tokens, [], condition_tensors) # type: ignore
366
+ logits = model_output.logits
367
+ mask = padding_mask & model_output.mask
368
+ ce, ce_per_codebook = self._compute_cross_entropy(logits, audio_tokens, mask)
369
+ loss = ce
370
+ self.deadlock_detect.update('loss')
371
+
372
+ if check_synchronization_points:
373
+ torch.cuda.set_sync_debug_mode('default')
374
+
375
+ if self.is_training:
376
+ metrics['lr'] = self.optimizer.param_groups[0]['lr']
377
+ if self.scaler is not None:
378
+ loss = self.scaler.scale(loss)
379
+ self.deadlock_detect.update('scale')
380
+ if self.cfg.fsdp.use:
381
+ loss.backward()
382
+ flashy.distrib.average_tensors(self.model.buffers())
383
+ elif self.cfg.optim.eager_sync:
384
+ with flashy.distrib.eager_sync_model(self.model):
385
+ loss.backward()
386
+ else:
387
+ # this should always be slower but can be useful
388
+ # for weird use cases like multiple backwards.
389
+ loss.backward()
390
+ flashy.distrib.sync_model(self.model)
391
+ self.deadlock_detect.update('backward')
392
+
393
+ if self.scaler is not None:
394
+ self.scaler.unscale_(self.optimizer)
395
+ if self.cfg.optim.max_norm:
396
+ if self.cfg.fsdp.use:
397
+ metrics['grad_norm'] = self.model.clip_grad_norm_(self.cfg.optim.max_norm) # type: ignore
398
+ else:
399
+ metrics['grad_norm'] = torch.nn.utils.clip_grad_norm_(
400
+ self.model.parameters(), self.cfg.optim.max_norm
401
+ )
402
+ if self.scaler is None:
403
+ self.optimizer.step()
404
+ else:
405
+ self.scaler.step(self.optimizer)
406
+ self.scaler.update()
407
+ if self.lr_scheduler:
408
+ self.lr_scheduler.step()
409
+ self.optimizer.zero_grad()
410
+ self.deadlock_detect.update('optim')
411
+ if self.scaler is not None:
412
+ scale = self.scaler.get_scale()
413
+ metrics['grad_scale'] = scale
414
+ if not loss.isfinite().all():
415
+ raise RuntimeError("Model probably diverged.")
416
+
417
+ metrics['ce'] = ce
418
+ metrics['ppl'] = torch.exp(ce)
419
+ for k, ce_q in enumerate(ce_per_codebook):
420
+ metrics[f'ce_q{k + 1}'] = ce_q
421
+ metrics[f'ppl_q{k + 1}'] = torch.exp(ce_q)
422
+
423
+ return metrics
424
+
425
+ @torch.no_grad()
426
+ def run_generate_step(self, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]],
427
+ gen_duration: float, prompt_duration: tp.Optional[float] = None,
428
+ remove_prompt: bool = False,
429
+ **generation_params) -> dict:
430
+ """Run generate step on a batch of optional audio tensor and corresponding attributes.
431
+
432
+ Args:
433
+ batch (tuple[torch.Tensor, list[SegmentWithAttributes]]): Batch of audio tensor and corresponding metadata.
435
+ gen_duration (float): Target audio duration for the generation.
436
+ prompt_duration (float, optional): Duration for the audio prompt to use for continuation.
437
+ remove_prompt (bool, optional): Whether to remove the prompt from the generated audio.
438
+ generation_params: Additional generation parameters.
439
+ Returns:
440
+ gen_outputs (dict): Generation outputs, consisting of audio and audio tokens from both the generation
441
+ and the prompt along with additional information.
442
+ """
443
+ bench_start = time.time()
444
+ audio, meta = batch
445
+ assert audio.size(0) == len(meta), (
446
+ f"Mismatch between number of items in audio batch ({audio.size(0)})",
447
+ f" and in metadata ({len(meta)})"
448
+ )
449
+ # prepare attributes
450
+ attributes = [x.to_condition_attributes() for x in meta]
451
+ # TODO: Add dropout for chroma?
452
+
453
+ # prepare audio prompt
454
+ if prompt_duration is None:
455
+ prompt_audio = None
456
+ else:
457
+ assert prompt_duration < gen_duration, "Prompt duration must be lower than target generation duration"
458
+ prompt_audio_frames = int(prompt_duration * self.compression_model.sample_rate)
459
+ prompt_audio = audio[..., :prompt_audio_frames]
460
+
461
+ # get audio tokens from compression model
462
+ if prompt_audio is None or prompt_audio.nelement() == 0:
463
+ num_samples = len(attributes)
464
+ prompt_tokens = None
465
+ else:
466
+ num_samples = None
467
+ prompt_audio = prompt_audio.to(self.device)
468
+ prompt_tokens, scale = self.compression_model.encode(prompt_audio)
469
+ assert scale is None, "Compression model in MusicGen should not require rescaling."
470
+
471
+ # generate by sampling from the LM
472
+ with self.autocast:
473
+ total_gen_len = math.ceil(gen_duration * self.compression_model.frame_rate)
474
+ gen_tokens = self.model.generate(
475
+ prompt_tokens, attributes, max_gen_len=total_gen_len,
476
+ num_samples=num_samples, **self.generation_params)
477
+
478
+ # generate audio from tokens
479
+ assert gen_tokens.dim() == 3
480
+ gen_audio = self.compression_model.decode(gen_tokens, None)
481
+
482
+ bench_end = time.time()
483
+ gen_outputs = {
484
+ 'rtf': (bench_end - bench_start) / gen_duration,
485
+ 'ref_audio': audio,
486
+ 'gen_audio': gen_audio,
487
+ 'gen_tokens': gen_tokens,
488
+ 'prompt_audio': prompt_audio,
489
+ 'prompt_tokens': prompt_tokens,
490
+ }
491
+ return gen_outputs
492
+
493
+ def generate_audio(self) -> dict:
494
+ """Audio generation stage."""
495
+ generate_stage_name = f'{self.current_stage}'
496
+ sample_manager = SampleManager(self.xp)
497
+ self.logger.info(f"Generating samples in {sample_manager.base_folder}")
498
+ loader = self.dataloaders['generate']
499
+ updates = len(loader)
500
+ lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates)
501
+
502
+ dataset = get_dataset_from_loader(loader)
503
+ dataset_duration = dataset.segment_duration
504
+ assert dataset_duration is not None
505
+ assert isinstance(dataset, AudioDataset)
506
+ target_duration = self.cfg.generate.lm.gen_duration
507
+ prompt_duration = self.cfg.generate.lm.prompt_duration
508
+ if target_duration is None:
509
+ target_duration = dataset_duration
510
+ if prompt_duration is None:
511
+ prompt_duration = dataset_duration / 4
512
+ assert prompt_duration < dataset_duration, (
513
+ f"Specified prompt duration ({prompt_duration}s) is longer",
514
+ f" than reference audio duration ({dataset_duration}s)"
515
+ )
516
+
517
+ def get_hydrated_conditions(meta: tp.List[SegmentWithAttributes]):
518
+ hydrated_conditions = []
519
+ for sample in [x.to_condition_attributes() for x in meta]:
520
+ cond_dict = {}
521
+ for cond_type in sample.__annotations__.keys():
522
+ for cond_key, cond_val in getattr(sample, cond_type).items():
523
+ if cond_key not in self.model.condition_provider.conditioners.keys():
524
+ continue
525
+ if is_jsonable(cond_val):
526
+ cond_dict[cond_key] = cond_val
527
+ elif isinstance(cond_val, WavCondition):
528
+ cond_dict[cond_key] = cond_val.path
529
+ elif isinstance(cond_val, JointEmbedCondition):
530
+ cond_dict[cond_key] = cond_val.text # only support text at inference for now
531
+ else:
532
+ # if we reached this point, it is not clear how to log the condition
533
+ # so we just log the type.
534
+ cond_dict[cond_key] = str(type(cond_val))
535
+ continue
536
+ hydrated_conditions.append(cond_dict)
537
+ return hydrated_conditions
538
+
539
+ metrics: dict = {}
540
+ average = flashy.averager()
541
+ for batch in lp:
542
+ audio, meta = batch
543
+ # metadata for sample manager
544
+ hydrated_conditions = get_hydrated_conditions(meta)
545
+ sample_generation_params = {
546
+ **{f'classifier_free_guidance_{k}': v for k, v in self.cfg.classifier_free_guidance.items()},
547
+ **self.generation_params
548
+ }
549
+ if self.cfg.generate.lm.unprompted_samples:
550
+ if self.cfg.generate.lm.gen_gt_samples:
551
+ # get the ground truth instead of generation
552
+ self.logger.warning(
553
+ "Use ground truth instead of audio generation as generate.lm.gen_gt_samples=true")
554
+ gen_unprompted_audio = audio
555
+ rtf = 1.
556
+ else:
557
+ gen_unprompted_outputs = self.run_generate_step(
558
+ batch, gen_duration=target_duration, prompt_duration=None,
559
+ **self.generation_params)
560
+ gen_unprompted_audio = gen_unprompted_outputs['gen_audio'].cpu()
561
+ rtf = gen_unprompted_outputs['rtf']
562
+ sample_manager.add_samples(
563
+ gen_unprompted_audio, self.epoch, hydrated_conditions,
564
+ ground_truth_wavs=audio, generation_args=sample_generation_params)
565
+
566
+ if self.cfg.generate.lm.prompted_samples:
567
+ gen_outputs = self.run_generate_step(
568
+ batch, gen_duration=target_duration, prompt_duration=prompt_duration,
569
+ **self.generation_params)
570
+ gen_audio = gen_outputs['gen_audio'].cpu()
571
+ prompt_audio = gen_outputs['prompt_audio'].cpu()
572
+ sample_manager.add_samples(
573
+ gen_audio, self.epoch, hydrated_conditions,
574
+ prompt_wavs=prompt_audio, ground_truth_wavs=audio,
575
+ generation_args=sample_generation_params)
576
+
577
+ metrics['rtf'] = rtf
578
+ metrics = average(metrics)
579
+
580
+ flashy.distrib.barrier()
581
+ return metrics
582
+
583
+ def generate(self) -> dict:
584
+ """Generate stage."""
585
+ self.model.eval()
586
+ with torch.no_grad():
587
+ return self.generate_audio()
588
+
589
+ def run_epoch(self):
590
+ if self.cfg.cache.write:
591
+ if ((self.epoch - 1) % self.cfg.cache.write_num_shards) != self.cfg.cache.write_shard:
592
+ return
593
+ super().run_epoch()
594
+
595
+ def train(self):
596
+ """Train stage.
597
+ """
598
+ if self._cached_batch_writer is not None:
599
+ self._cached_batch_writer.start_epoch(self.epoch)
600
+ if self._cached_batch_loader is None:
601
+ dataset = get_dataset_from_loader(self.dataloaders['train'])
602
+ assert isinstance(dataset, AudioDataset)
603
+ dataset.current_epoch = self.epoch
604
+ else:
605
+ self._cached_batch_loader.start_epoch(self.epoch)
606
+ return super().train()
607
+
+     def evaluate_audio_generation(self) -> dict:
+         """Evaluate audio generation with off-the-shelf metrics."""
+         evaluate_stage_name = f'{self.current_stage}_generation'
+         # instantiate the evaluation metrics; if at least one metric is enabled,
+         # run the audio generation evaluation
+         fad: tp.Optional[eval_metrics.FrechetAudioDistanceMetric] = None
+         kldiv: tp.Optional[eval_metrics.KLDivergenceMetric] = None
+         text_consistency: tp.Optional[eval_metrics.TextConsistencyMetric] = None
+         chroma_cosine: tp.Optional[eval_metrics.ChromaCosineSimilarityMetric] = None
+         should_run_eval = False
+         eval_chroma_wavs: tp.Optional[torch.Tensor] = None
+         if self.cfg.evaluate.metrics.fad:
+             fad = builders.get_fad(self.cfg.metrics.fad).to(self.device)
+             should_run_eval = True
+         if self.cfg.evaluate.metrics.kld:
+             kldiv = builders.get_kldiv(self.cfg.metrics.kld).to(self.device)
+             should_run_eval = True
+         if self.cfg.evaluate.metrics.text_consistency:
+             text_consistency = builders.get_text_consistency(self.cfg.metrics.text_consistency).to(self.device)
+             should_run_eval = True
+         if self.cfg.evaluate.metrics.chroma_cosine:
+             chroma_cosine = builders.get_chroma_cosine_similarity(self.cfg.metrics.chroma_cosine).to(self.device)
+             # if the chroma conditioner has predefined eval wavs, clear them while computing
+             # the cosine metric so the evaluation set's own chromas are used
+             has_predefined_eval_chromas = 'self_wav' in self.model.condition_provider.conditioners and \
+                 self.model.condition_provider.conditioners['self_wav'].has_eval_wavs()
+             if has_predefined_eval_chromas:
+                 warn_once(self.logger, "Attempting to run cosine eval for config with pre-defined eval chromas! "
+                                        "Resetting eval chromas to None for evaluation.")
+                 eval_chroma_wavs = self.model.condition_provider.conditioners.self_wav.eval_wavs  # type: ignore
+                 self.model.condition_provider.conditioners.self_wav.reset_eval_wavs(None)  # type: ignore
+             should_run_eval = True
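The switches above live under evaluate.metrics, while each metric's own options (such as use_gt) live under metrics.<name>. A hedged sketch of the corresponding config fragment, built with omegaconf so the attribute accesses in the code resolve; the values are placeholders:

from omegaconf import OmegaConf

cfg = OmegaConf.create({
    'evaluate': {'metrics': {'base': True, 'fad': True, 'kld': False,
                             'text_consistency': False, 'chroma_cosine': False}},
    'metrics': {'fad': {'use_gt': False}},  # per-metric options, e.g. the use_gt switch
})
print(cfg.evaluate.metrics.fad)  # True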
+
+         def get_compressed_audio(audio: torch.Tensor) -> torch.Tensor:
+             audio_tokens, scale = self.compression_model.encode(audio.to(self.device))
+             compressed_audio = self.compression_model.decode(audio_tokens, scale)
+             return compressed_audio[..., :audio.shape[-1]]
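get_compressed_audio round-trips ground-truth audio through the compression model: encode to discrete tokens, decode back, then trim the decoder output to the input length. A toy stand-in illustrating just the trimming logic (ToyCodec is hypothetical, not the real compression model API):

import torch

class ToyCodec:
    """Hypothetical stand-in for the compression model, for illustration only."""
    def encode(self, audio: torch.Tensor):
        # pretend to tokenize: (batch, codebooks, frames), no rescaling
        return torch.zeros(audio.shape[0], 4, 10, dtype=torch.long), None

    def decode(self, tokens: torch.Tensor, scale=None) -> torch.Tensor:
        # decoders may emit a few extra frames of padding
        return torch.zeros(tokens.shape[0], 1, 16384)

audio = torch.randn(2, 1, 16000)
codec = ToyCodec()
tokens, scale = codec.encode(audio)
roundtrip = codec.decode(tokens, scale)[..., :audio.shape[-1]]  # trim to the input length
print(roundtrip.shape)  # torch.Size([2, 1, 16000])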
+
+         metrics: dict = {}
+         if should_run_eval:
+             loader = self.dataloaders['evaluate']
+             updates = len(loader)
+             lp = self.log_progress(f'{evaluate_stage_name} inference', loader, total=updates, updates=self.log_updates)
+             average = flashy.averager()
+             dataset = get_dataset_from_loader(loader)
+             assert isinstance(dataset, AudioDataset)
+             self.logger.info(f"Computing evaluation metrics on {len(dataset)} samples")
+
+             for batch in lp:
+                 audio, meta = batch
+                 assert all(self.cfg.sample_rate == m.sample_rate for m in meta)
+
+                 target_duration = audio.shape[-1] / self.cfg.sample_rate
+                 if self.cfg.evaluate.fixed_generation_duration:
+                     target_duration = self.cfg.evaluate.fixed_generation_duration
+
+                 gen_outputs = self.run_generate_step(
+                     batch, gen_duration=target_duration,
+                     **self.generation_params
+                 )
+                 y_pred = gen_outputs['gen_audio'].detach()
+                 y_pred = y_pred[..., :audio.shape[-1]]
+
+                 normalize_kwargs = dict(self.cfg.generate.audio)
+                 normalize_kwargs.pop('format', None)
+                 y_pred = torch.stack([normalize_audio(w, **normalize_kwargs) for w in y_pred], dim=0).cpu()
+                 y = audio.cpu()  # should already be on CPU but just in case
+                 sizes = torch.tensor([m.n_frames for m in meta])  # actual sizes without padding
+                 sample_rates = torch.tensor([m.sample_rate for m in meta])  # sample rates for audio samples
+                 audio_stems = [Path(m.meta.path).stem + f"_{m.seek_time}" for m in meta]
+
+                 # when <metric>.use_gt is set, the metric is fed the ground truth
+                 # round-tripped through the compression model instead of the generations
+                 if fad is not None:
+                     if self.cfg.metrics.fad.use_gt:
+                         y_pred = get_compressed_audio(y).cpu()
+                     fad.update(y_pred, y, sizes, sample_rates, audio_stems)
+                 if kldiv is not None:
+                     if self.cfg.metrics.kld.use_gt:
+                         y_pred = get_compressed_audio(y).cpu()
+                     kldiv.update(y_pred, y, sizes, sample_rates)
+                 if text_consistency is not None:
+                     texts = [m.description for m in meta]
+                     if self.cfg.metrics.text_consistency.use_gt:
+                         y_pred = y
+                     text_consistency.update(y_pred, texts, sizes, sample_rates)
+                 if chroma_cosine is not None:
+                     if self.cfg.metrics.chroma_cosine.use_gt:
+                         y_pred = get_compressed_audio(y).cpu()
+                     chroma_cosine.update(y_pred, y, sizes, sample_rates)
+             # restore the chroma conditioner's predefined eval chroma wavs
+             if eval_chroma_wavs is not None:
+                 self.model.condition_provider.conditioners['self_wav'].reset_eval_wavs(eval_chroma_wavs)
+
+             flashy.distrib.barrier()
+             if fad is not None:
+                 metrics['fad'] = fad.compute()
+             if kldiv is not None:
+                 kld_metrics = kldiv.compute()
+                 metrics.update(kld_metrics)
+             if text_consistency is not None:
+                 metrics['text_consistency'] = text_consistency.compute()
+             if chroma_cosine is not None:
+                 metrics['chroma_cosine'] = chroma_cosine.compute()
+             metrics = average(metrics)
+             metrics = flashy.distrib.average_metrics(metrics, len(loader))
+
+         return metrics
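flashy.distrib.average_metrics then reconciles the per-worker averages across the distributed job; conceptually it is a count-weighted mean, illustrated here in plain Python (this shows the idea, not flashy's implementation):

# Two workers with per-worker mean metrics and their sample counts.
worker_metrics = [({'fad': 2.0}, 10), ({'fad': 4.0}, 30)]
total = sum(count for _, count in worker_metrics)
fad = sum(m['fad'] * count for m, count in worker_metrics) / total
print(fad)  # 3.5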
+
+     def evaluate(self) -> dict:
+         """Evaluate stage."""
+         self.model.eval()
+         with torch.no_grad():
+             metrics: dict = {}
+             if self.cfg.evaluate.metrics.base:
+                 metrics.update(self.common_train_valid('evaluate'))
+             gen_metrics = self.evaluate_audio_generation()
+             return {**metrics, **gen_metrics}
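Putting it together, the evaluate stage returns one flat metrics dict: the base language-model metrics merged with the generation metrics. A minimal sketch of the merge semantics, with made-up keys and values:

base_metrics = {'ce': 3.1, 'ppl': 22.2}  # e.g. from common_train_valid('evaluate'); values made up
gen_metrics = {'fad': 1.8}               # e.g. from evaluate_audio_generation(); values made up
print({**base_metrics, **gen_metrics})
# {'ce': 3.1, 'ppl': 22.2, 'fad': 1.8}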