unpairedelectron07 commited on
Commit
982b37b
·
verified ·
1 Parent(s): 28d2e51

Upload 11 files

Browse files
audiocraft/models/audiogen.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Main model for using AudioGen. This will combine all the required components
9
+ and provide easy access to the generation API.
10
+ """
11
+
12
+ import typing as tp
13
+
14
+ import torch
15
+
16
+ from .encodec import CompressionModel
17
+ from .genmodel import BaseGenModel
18
+ from .lm import LMModel
19
+ from .builders import get_debug_compression_model, get_debug_lm_model
20
+ from .loaders import load_compression_model, load_lm_model
21
+
22
+
23
+ class AudioGen(BaseGenModel):
24
+ """AudioGen main model with convenient generation API.
25
+
26
+ Args:
27
+ name (str): name of the model.
28
+ compression_model (CompressionModel): Compression model
29
+ used to map audio to invertible discrete representations.
30
+ lm (LMModel): Language model over discrete representations.
31
+ max_duration (float, optional): maximum duration the model can produce,
32
+ otherwise, inferred from the training params.
33
+ """
34
+ def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
35
+ max_duration: tp.Optional[float] = None):
36
+ super().__init__(name, compression_model, lm, max_duration)
37
+ self.set_generation_params(duration=5) # default duration
38
+
39
+ @staticmethod
40
+ def get_pretrained(name: str = 'facebook/audiogen-medium', device=None):
41
+ """Return pretrained model, we provide a single model for now:
42
+ - facebook/audiogen-medium (1.5B), text to sound,
43
+ # see: https://huggingface.co/facebook/audiogen-medium
44
+ """
45
+ if device is None:
46
+ if torch.cuda.device_count():
47
+ device = 'cuda'
48
+ else:
49
+ device = 'cpu'
50
+
51
+ if name == 'debug':
52
+ # used only for unit tests
53
+ compression_model = get_debug_compression_model(device, sample_rate=16000)
54
+ lm = get_debug_lm_model(device)
55
+ return AudioGen(name, compression_model, lm, max_duration=10)
56
+
57
+ compression_model = load_compression_model(name, device=device)
58
+ lm = load_lm_model(name, device=device)
59
+ assert 'self_wav' not in lm.condition_provider.conditioners, \
60
+ "AudioGen do not support waveform conditioning for now"
61
+ return AudioGen(name, compression_model, lm)
62
+
63
+ def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
64
+ top_p: float = 0.0, temperature: float = 1.0,
65
+ duration: float = 10.0, cfg_coef: float = 3.0,
66
+ two_step_cfg: bool = False, extend_stride: float = 2):
67
+ """Set the generation parameters for AudioGen.
68
+
69
+ Args:
70
+ use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True.
71
+ top_k (int, optional): top_k used for sampling. Defaults to 250.
72
+ top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.
73
+ temperature (float, optional): Softmax temperature parameter. Defaults to 1.0.
74
+ duration (float, optional): Duration of the generated waveform. Defaults to 10.0.
75
+ cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0.
76
+ two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
77
+ instead of batching together the two. This has some impact on how things
78
+ are padded but seems to have little impact in practice.
79
+ extend_stride: when doing extended generation (i.e. more than 10 seconds), by how much
80
+ should we extend the audio each time. Larger values will mean less context is
81
+ preserved, and shorter value will require extra computations.
82
+ """
83
+ assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
84
+ self.extend_stride = extend_stride
85
+ self.duration = duration
86
+ self.generation_params = {
87
+ 'use_sampling': use_sampling,
88
+ 'temp': temperature,
89
+ 'top_k': top_k,
90
+ 'top_p': top_p,
91
+ 'cfg_coef': cfg_coef,
92
+ 'two_step_cfg': two_step_cfg,
93
+ }
audiocraft/models/builders.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ All the functions to build the relevant models and modules
9
+ from the Hydra config.
10
+ """
11
+
12
+ import typing as tp
13
+
14
+ import audiocraft
15
+ import omegaconf
16
+ import torch
17
+
18
+ from .encodec import CompressionModel, EncodecModel, InterleaveStereoCompressionModel
19
+ from .lm import LMModel
20
+ from .lm_magnet import MagnetLMModel
21
+ from ..modules.codebooks_patterns import (
22
+ CodebooksPatternProvider,
23
+ DelayedPatternProvider,
24
+ MusicLMPattern,
25
+ ParallelPatternProvider,
26
+ UnrolledPatternProvider,
27
+ CoarseFirstPattern,
28
+ )
29
+ from ..modules.conditioners import (
30
+ BaseConditioner,
31
+ ChromaStemConditioner,
32
+ CLAPEmbeddingConditioner,
33
+ ConditionFuser,
34
+ ConditioningProvider,
35
+ LUTConditioner,
36
+ T5Conditioner,
37
+ )
38
+ from .unet import DiffusionUnet
39
+ from .. import quantization as qt
40
+ from ..utils.utils import dict_from_config
41
+ from ..modules.diffusion_schedule import MultiBandProcessor, SampleProcessor
42
+
43
+
44
+ def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: int) -> qt.BaseQuantizer:
45
+ klass = {
46
+ 'no_quant': qt.DummyQuantizer,
47
+ 'rvq': qt.ResidualVectorQuantizer
48
+ }[quantizer]
49
+ kwargs = dict_from_config(getattr(cfg, quantizer))
50
+ if quantizer != 'no_quant':
51
+ kwargs['dimension'] = dimension
52
+ return klass(**kwargs)
53
+
54
+
55
+ def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig):
56
+ if encoder_name == 'seanet':
57
+ kwargs = dict_from_config(getattr(cfg, 'seanet'))
58
+ encoder_override_kwargs = kwargs.pop('encoder')
59
+ decoder_override_kwargs = kwargs.pop('decoder')
60
+ encoder_kwargs = {**kwargs, **encoder_override_kwargs}
61
+ decoder_kwargs = {**kwargs, **decoder_override_kwargs}
62
+ encoder = audiocraft.modules.SEANetEncoder(**encoder_kwargs)
63
+ decoder = audiocraft.modules.SEANetDecoder(**decoder_kwargs)
64
+ return encoder, decoder
65
+ else:
66
+ raise KeyError(f"Unexpected compression model {cfg.compression_model}")
67
+
68
+
69
+ def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel:
70
+ """Instantiate a compression model."""
71
+ if cfg.compression_model == 'encodec':
72
+ kwargs = dict_from_config(getattr(cfg, 'encodec'))
73
+ encoder_name = kwargs.pop('autoencoder')
74
+ quantizer_name = kwargs.pop('quantizer')
75
+ encoder, decoder = get_encodec_autoencoder(encoder_name, cfg)
76
+ quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension)
77
+ frame_rate = kwargs['sample_rate'] // encoder.hop_length
78
+ renormalize = kwargs.pop('renormalize', False)
79
+ # deprecated params
80
+ kwargs.pop('renorm', None)
81
+ return EncodecModel(encoder, decoder, quantizer,
82
+ frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device)
83
+ else:
84
+ raise KeyError(f"Unexpected compression model {cfg.compression_model}")
85
+
86
+
87
+ def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel:
88
+ """Instantiate a transformer LM."""
89
+ if cfg.lm_model in ['transformer_lm', 'transformer_lm_magnet']:
90
+ kwargs = dict_from_config(getattr(cfg, 'transformer_lm'))
91
+ n_q = kwargs['n_q']
92
+ q_modeling = kwargs.pop('q_modeling', None)
93
+ codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern')
94
+ attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout'))
95
+ cls_free_guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance'))
96
+ cfg_prob, cfg_coef = cls_free_guidance['training_dropout'], cls_free_guidance['inference_coef']
97
+ fuser = get_condition_fuser(cfg)
98
+ condition_provider = get_conditioner_provider(kwargs["dim"], cfg).to(cfg.device)
99
+ if len(fuser.fuse2cond['cross']) > 0: # enforce cross-att programmatically
100
+ kwargs['cross_attention'] = True
101
+ if codebooks_pattern_cfg.modeling is None:
102
+ assert q_modeling is not None, \
103
+ "LM model should either have a codebook pattern defined or transformer_lm.q_modeling"
104
+ codebooks_pattern_cfg = omegaconf.OmegaConf.create(
105
+ {'modeling': q_modeling, 'delay': {'delays': list(range(n_q))}}
106
+ )
107
+
108
+ pattern_provider = get_codebooks_pattern_provider(n_q, codebooks_pattern_cfg)
109
+ lm_class = MagnetLMModel if cfg.lm_model == 'transformer_lm_magnet' else LMModel
110
+ return lm_class(
111
+ pattern_provider=pattern_provider,
112
+ condition_provider=condition_provider,
113
+ fuser=fuser,
114
+ cfg_dropout=cfg_prob,
115
+ cfg_coef=cfg_coef,
116
+ attribute_dropout=attribute_dropout,
117
+ dtype=getattr(torch, cfg.dtype),
118
+ device=cfg.device,
119
+ **kwargs
120
+ ).to(cfg.device)
121
+ else:
122
+ raise KeyError(f"Unexpected LM model {cfg.lm_model}")
123
+
124
+
125
+ def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> ConditioningProvider:
126
+ """Instantiate a conditioning model."""
127
+ device = cfg.device
128
+ duration = cfg.dataset.segment_duration
129
+ cfg = getattr(cfg, 'conditioners')
130
+ dict_cfg = {} if cfg is None else dict_from_config(cfg)
131
+ conditioners: tp.Dict[str, BaseConditioner] = {}
132
+ condition_provider_args = dict_cfg.pop('args', {})
133
+ condition_provider_args.pop('merge_text_conditions_p', None)
134
+ condition_provider_args.pop('drop_desc_p', None)
135
+
136
+ for cond, cond_cfg in dict_cfg.items():
137
+ model_type = cond_cfg['model']
138
+ model_args = cond_cfg[model_type]
139
+ if model_type == 't5':
140
+ conditioners[str(cond)] = T5Conditioner(output_dim=output_dim, device=device, **model_args)
141
+ elif model_type == 'lut':
142
+ conditioners[str(cond)] = LUTConditioner(output_dim=output_dim, **model_args)
143
+ elif model_type == 'chroma_stem':
144
+ conditioners[str(cond)] = ChromaStemConditioner(
145
+ output_dim=output_dim,
146
+ duration=duration,
147
+ device=device,
148
+ **model_args
149
+ )
150
+ elif model_type == 'clap':
151
+ conditioners[str(cond)] = CLAPEmbeddingConditioner(
152
+ output_dim=output_dim,
153
+ device=device,
154
+ **model_args
155
+ )
156
+ else:
157
+ raise ValueError(f"Unrecognized conditioning model: {model_type}")
158
+ conditioner = ConditioningProvider(conditioners, device=device, **condition_provider_args)
159
+ return conditioner
160
+
161
+
162
+ def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:
163
+ """Instantiate a condition fuser object."""
164
+ fuser_cfg = getattr(cfg, 'fuser')
165
+ fuser_methods = ['sum', 'cross', 'prepend', 'input_interpolate']
166
+ fuse2cond = {k: fuser_cfg[k] for k in fuser_methods}
167
+ kwargs = {k: v for k, v in fuser_cfg.items() if k not in fuser_methods}
168
+ fuser = ConditionFuser(fuse2cond=fuse2cond, **kwargs)
169
+ return fuser
170
+
171
+
172
+ def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider:
173
+ """Instantiate a codebooks pattern provider object."""
174
+ pattern_providers = {
175
+ 'parallel': ParallelPatternProvider,
176
+ 'delay': DelayedPatternProvider,
177
+ 'unroll': UnrolledPatternProvider,
178
+ 'coarse_first': CoarseFirstPattern,
179
+ 'musiclm': MusicLMPattern,
180
+ }
181
+ name = cfg.modeling
182
+ kwargs = dict_from_config(cfg.get(name)) if hasattr(cfg, name) else {}
183
+ klass = pattern_providers[name]
184
+ return klass(n_q, **kwargs)
185
+
186
+
187
+ def get_debug_compression_model(device='cpu', sample_rate: int = 32000):
188
+ """Instantiate a debug compression model to be used for unit tests."""
189
+ assert sample_rate in [16000, 32000], "unsupported sample rate for debug compression model"
190
+ model_ratios = {
191
+ 16000: [10, 8, 8], # 25 Hz at 16kHz
192
+ 32000: [10, 8, 16] # 25 Hz at 32kHz
193
+ }
194
+ ratios: tp.List[int] = model_ratios[sample_rate]
195
+ frame_rate = 25
196
+ seanet_kwargs: dict = {
197
+ 'n_filters': 4,
198
+ 'n_residual_layers': 1,
199
+ 'dimension': 32,
200
+ 'ratios': ratios,
201
+ }
202
+ encoder = audiocraft.modules.SEANetEncoder(**seanet_kwargs)
203
+ decoder = audiocraft.modules.SEANetDecoder(**seanet_kwargs)
204
+ quantizer = qt.ResidualVectorQuantizer(dimension=32, bins=400, n_q=4)
205
+ init_x = torch.randn(8, 32, 128)
206
+ quantizer(init_x, 1) # initialize kmeans etc.
207
+ compression_model = EncodecModel(
208
+ encoder, decoder, quantizer,
209
+ frame_rate=frame_rate, sample_rate=sample_rate, channels=1).to(device)
210
+ return compression_model.eval()
211
+
212
+
213
+ def get_diffusion_model(cfg: omegaconf.DictConfig):
214
+ # TODO Find a way to infer the channels from dset
215
+ channels = cfg.channels
216
+ num_steps = cfg.schedule.num_steps
217
+ return DiffusionUnet(
218
+ chin=channels, num_steps=num_steps, **cfg.diffusion_unet)
219
+
220
+
221
+ def get_processor(cfg, sample_rate: int = 24000):
222
+ sample_processor = SampleProcessor()
223
+ if cfg.use:
224
+ kw = dict(cfg)
225
+ kw.pop('use')
226
+ kw.pop('name')
227
+ if cfg.name == "multi_band_processor":
228
+ sample_processor = MultiBandProcessor(sample_rate=sample_rate, **kw)
229
+ return sample_processor
230
+
231
+
232
+ def get_debug_lm_model(device='cpu'):
233
+ """Instantiate a debug LM to be used for unit tests."""
234
+ pattern = DelayedPatternProvider(n_q=4)
235
+ dim = 16
236
+ providers = {
237
+ 'description': LUTConditioner(n_bins=128, dim=dim, output_dim=dim, tokenizer="whitespace"),
238
+ }
239
+ condition_provider = ConditioningProvider(providers)
240
+ fuser = ConditionFuser(
241
+ {'cross': ['description'], 'prepend': [],
242
+ 'sum': [], 'input_interpolate': []})
243
+ lm = LMModel(
244
+ pattern, condition_provider, fuser,
245
+ n_q=4, card=400, dim=dim, num_heads=4, custom=True, num_layers=2,
246
+ cross_attention=True, causal=True)
247
+ return lm.to(device).eval()
248
+
249
+
250
+ def get_wrapped_compression_model(
251
+ compression_model: CompressionModel,
252
+ cfg: omegaconf.DictConfig) -> CompressionModel:
253
+ if hasattr(cfg, 'interleave_stereo_codebooks'):
254
+ if cfg.interleave_stereo_codebooks.use:
255
+ kwargs = dict_from_config(cfg.interleave_stereo_codebooks)
256
+ kwargs.pop('use')
257
+ compression_model = InterleaveStereoCompressionModel(compression_model, **kwargs)
258
+ if hasattr(cfg, 'compression_model_n_q'):
259
+ if cfg.compression_model_n_q is not None:
260
+ compression_model.set_num_codebooks(cfg.compression_model_n_q)
261
+ return compression_model
audiocraft/models/encodec.py ADDED
@@ -0,0 +1,506 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ """Compression models or wrapper around existing models.
7
+ Also defines the main interface that a model must follow to be usable as an audio tokenizer.
8
+ """
9
+
10
+ from abc import ABC, abstractmethod
11
+ import logging
12
+ import math
13
+ from pathlib import Path
14
+ import typing as tp
15
+
16
+ from einops import rearrange
17
+ import numpy as np
18
+ import torch
19
+ from torch import nn
20
+ from transformers import EncodecModel as HFEncodecModel
21
+
22
+ from .. import quantization as qt
23
+
24
+
25
+ logger = logging.getLogger()
26
+
27
+
28
+ class CompressionModel(ABC, nn.Module):
29
+ """Base API for all compression models that aim at being used as audio tokenizers
30
+ with a language model.
31
+ """
32
+
33
+ @abstractmethod
34
+ def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
35
+ ...
36
+
37
+ @abstractmethod
38
+ def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
39
+ """See `EncodecModel.encode`."""
40
+ ...
41
+
42
+ @abstractmethod
43
+ def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
44
+ """See `EncodecModel.decode`."""
45
+ ...
46
+
47
+ @abstractmethod
48
+ def decode_latent(self, codes: torch.Tensor):
49
+ """Decode from the discrete codes to continuous latent space."""
50
+ ...
51
+
52
+ @property
53
+ @abstractmethod
54
+ def channels(self) -> int:
55
+ ...
56
+
57
+ @property
58
+ @abstractmethod
59
+ def frame_rate(self) -> float:
60
+ ...
61
+
62
+ @property
63
+ @abstractmethod
64
+ def sample_rate(self) -> int:
65
+ ...
66
+
67
+ @property
68
+ @abstractmethod
69
+ def cardinality(self) -> int:
70
+ ...
71
+
72
+ @property
73
+ @abstractmethod
74
+ def num_codebooks(self) -> int:
75
+ ...
76
+
77
+ @property
78
+ @abstractmethod
79
+ def total_codebooks(self) -> int:
80
+ ...
81
+
82
+ @abstractmethod
83
+ def set_num_codebooks(self, n: int):
84
+ """Set the active number of codebooks used by the quantizer."""
85
+ ...
86
+
87
+ @staticmethod
88
+ def get_pretrained(
89
+ name: str, device: tp.Union[torch.device, str] = 'cpu'
90
+ ) -> 'CompressionModel':
91
+ """Instantiate a CompressionModel from a given pretrained model.
92
+
93
+ Args:
94
+ name (Path or str): name of the pretrained model. See after.
95
+ device (torch.device or str): Device on which the model is loaded.
96
+
97
+ Pretrained models:
98
+ - dac_44khz (https://github.com/descriptinc/descript-audio-codec)
99
+ - dac_24khz (same)
100
+ - facebook/encodec_24khz (https://huggingface.co/facebook/encodec_24khz)
101
+ - facebook/encodec_32khz (https://huggingface.co/facebook/encodec_32khz)
102
+ - your own model on Hugging Face. Export instructions to come...
103
+ """
104
+
105
+ from . import builders, loaders
106
+ model: CompressionModel
107
+ if name in ['dac_44khz', 'dac_24khz']:
108
+ model_type = name.split('_')[1]
109
+ logger.info("Getting pretrained compression model from DAC %s", model_type)
110
+ model = DAC(model_type)
111
+ elif name in ['debug_compression_model']:
112
+ logger.info("Getting pretrained compression model for debug")
113
+ model = builders.get_debug_compression_model()
114
+ elif Path(name).exists():
115
+ # We assume here if the path exists that it is in fact an AC checkpoint
116
+ # that was exported using `audiocraft.utils.export` functions.
117
+ model = loaders.load_compression_model(name, device=device)
118
+ else:
119
+ logger.info("Getting pretrained compression model from HF %s", name)
120
+ hf_model = HFEncodecModel.from_pretrained(name)
121
+ model = HFEncodecCompressionModel(hf_model).to(device)
122
+ return model.to(device).eval()
123
+
124
+
125
+ class EncodecModel(CompressionModel):
126
+ """Encodec model operating on the raw waveform.
127
+
128
+ Args:
129
+ encoder (nn.Module): Encoder network.
130
+ decoder (nn.Module): Decoder network.
131
+ quantizer (qt.BaseQuantizer): Quantizer network.
132
+ frame_rate (int): Frame rate for the latent representation.
133
+ sample_rate (int): Audio sample rate.
134
+ channels (int): Number of audio channels.
135
+ causal (bool): Whether to use a causal version of the model.
136
+ renormalize (bool): Whether to renormalize the audio before running the model.
137
+ """
138
+ # we need assignment to override the property in the abstract class,
139
+ # I couldn't find a better way...
140
+ frame_rate: float = 0
141
+ sample_rate: int = 0
142
+ channels: int = 0
143
+
144
+ def __init__(self,
145
+ encoder: nn.Module,
146
+ decoder: nn.Module,
147
+ quantizer: qt.BaseQuantizer,
148
+ frame_rate: int,
149
+ sample_rate: int,
150
+ channels: int,
151
+ causal: bool = False,
152
+ renormalize: bool = False):
153
+ super().__init__()
154
+ self.encoder = encoder
155
+ self.decoder = decoder
156
+ self.quantizer = quantizer
157
+ self.frame_rate = frame_rate
158
+ self.sample_rate = sample_rate
159
+ self.channels = channels
160
+ self.renormalize = renormalize
161
+ self.causal = causal
162
+ if self.causal:
163
+ # we force disabling here to avoid handling linear overlap of segments
164
+ # as supported in original EnCodec codebase.
165
+ assert not self.renormalize, 'Causal model does not support renormalize'
166
+
167
+ @property
168
+ def total_codebooks(self):
169
+ """Total number of quantizer codebooks available."""
170
+ return self.quantizer.total_codebooks
171
+
172
+ @property
173
+ def num_codebooks(self):
174
+ """Active number of codebooks used by the quantizer."""
175
+ return self.quantizer.num_codebooks
176
+
177
+ def set_num_codebooks(self, n: int):
178
+ """Set the active number of codebooks used by the quantizer."""
179
+ self.quantizer.set_num_codebooks(n)
180
+
181
+ @property
182
+ def cardinality(self):
183
+ """Cardinality of each codebook."""
184
+ return self.quantizer.bins
185
+
186
+ def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
187
+ scale: tp.Optional[torch.Tensor]
188
+ if self.renormalize:
189
+ mono = x.mean(dim=1, keepdim=True)
190
+ volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
191
+ scale = 1e-8 + volume
192
+ x = x / scale
193
+ scale = scale.view(-1, 1)
194
+ else:
195
+ scale = None
196
+ return x, scale
197
+
198
+ def postprocess(self,
199
+ x: torch.Tensor,
200
+ scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
201
+ if scale is not None:
202
+ assert self.renormalize
203
+ x = x * scale.view(-1, 1, 1)
204
+ return x
205
+
206
+ def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
207
+ assert x.dim() == 3
208
+ length = x.shape[-1]
209
+ x, scale = self.preprocess(x)
210
+
211
+ emb = self.encoder(x)
212
+ q_res = self.quantizer(emb, self.frame_rate)
213
+ out = self.decoder(q_res.x)
214
+
215
+ # remove extra padding added by the encoder and decoder
216
+ assert out.shape[-1] >= length, (out.shape[-1], length)
217
+ out = out[..., :length]
218
+
219
+ q_res.x = self.postprocess(out, scale)
220
+
221
+ return q_res
222
+
223
+ def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
224
+ """Encode the given input tensor to quantized representation along with scale parameter.
225
+
226
+ Args:
227
+ x (torch.Tensor): Float tensor of shape [B, C, T]
228
+
229
+ Returns:
230
+ codes, scale (tuple of torch.Tensor, torch.Tensor): Tuple composed of:
231
+ codes: a float tensor of shape [B, K, T] with K the number of codebooks used and T the timestep.
232
+ scale: a float tensor containing the scale for audio renormalization.
233
+ """
234
+ assert x.dim() == 3
235
+ x, scale = self.preprocess(x)
236
+ emb = self.encoder(x)
237
+ codes = self.quantizer.encode(emb)
238
+ return codes, scale
239
+
240
+ def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
241
+ """Decode the given codes to a reconstructed representation, using the scale to perform
242
+ audio denormalization if needed.
243
+
244
+ Args:
245
+ codes (torch.Tensor): Int tensor of shape [B, K, T]
246
+ scale (torch.Tensor, optional): Float tensor containing the scale value.
247
+
248
+ Returns:
249
+ out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio.
250
+ """
251
+ emb = self.decode_latent(codes)
252
+ out = self.decoder(emb)
253
+ out = self.postprocess(out, scale)
254
+ # out contains extra padding added by the encoder and decoder
255
+ return out
256
+
257
+ def decode_latent(self, codes: torch.Tensor):
258
+ """Decode from the discrete codes to continuous latent space."""
259
+ return self.quantizer.decode(codes)
260
+
261
+
262
+ class DAC(CompressionModel):
263
+ def __init__(self, model_type: str = "44khz"):
264
+ super().__init__()
265
+ try:
266
+ import dac.utils
267
+ except ImportError:
268
+ raise RuntimeError("Could not import dac, make sure it is installed, "
269
+ "please run `pip install descript-audio-codec`")
270
+ self.model = dac.utils.load_model(model_type=model_type)
271
+ self.n_quantizers = self.total_codebooks
272
+ self.model.eval()
273
+
274
+ def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
275
+ # We don't support training with this.
276
+ raise NotImplementedError("Forward and training with DAC not supported.")
277
+
278
+ def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
279
+ codes = self.model.encode(x, self.n_quantizers)[1]
280
+ return codes[:, :self.n_quantizers], None
281
+
282
+ def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
283
+ assert scale is None
284
+ z_q = self.decode_latent(codes)
285
+ return self.model.decode(z_q)
286
+
287
+ def decode_latent(self, codes: torch.Tensor):
288
+ """Decode from the discrete codes to continuous latent space."""
289
+ return self.model.quantizer.from_codes(codes)[0]
290
+
291
+ @property
292
+ def channels(self) -> int:
293
+ return 1
294
+
295
+ @property
296
+ def frame_rate(self) -> float:
297
+ return self.model.sample_rate / self.model.hop_length
298
+
299
+ @property
300
+ def sample_rate(self) -> int:
301
+ return self.model.sample_rate
302
+
303
+ @property
304
+ def cardinality(self) -> int:
305
+ return self.model.codebook_size
306
+
307
+ @property
308
+ def num_codebooks(self) -> int:
309
+ return self.n_quantizers
310
+
311
+ @property
312
+ def total_codebooks(self) -> int:
313
+ return self.model.n_codebooks
314
+
315
+ def set_num_codebooks(self, n: int):
316
+ """Set the active number of codebooks used by the quantizer.
317
+ """
318
+ assert n >= 1
319
+ assert n <= self.total_codebooks
320
+ self.n_quantizers = n
321
+
322
+
323
+ class HFEncodecCompressionModel(CompressionModel):
324
+ """Wrapper around HuggingFace Encodec.
325
+ """
326
+ def __init__(self, model: HFEncodecModel):
327
+ super().__init__()
328
+ self.model = model
329
+ bws = self.model.config.target_bandwidths
330
+ num_codebooks = [
331
+ bw * 1000 / (self.frame_rate * math.log2(self.cardinality))
332
+ for bw in bws
333
+ ]
334
+ deltas = [nc - int(nc) for nc in num_codebooks]
335
+ # Checking we didn't do some bad maths and we indeed have integers!
336
+ assert all(deltas) <= 1e-3, deltas
337
+ self.possible_num_codebooks = [int(nc) for nc in num_codebooks]
338
+ self.set_num_codebooks(max(self.possible_num_codebooks))
339
+
340
+ def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
341
+ # We don't support training with this.
342
+ raise NotImplementedError("Forward and training with HF EncodecModel not supported.")
343
+
344
+ def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
345
+ bandwidth_index = self.possible_num_codebooks.index(self.num_codebooks)
346
+ bandwidth = self.model.config.target_bandwidths[bandwidth_index]
347
+ res = self.model.encode(x, None, bandwidth)
348
+ assert len(res[0]) == 1
349
+ assert len(res[1]) == 1
350
+ return res[0][0], res[1][0]
351
+
352
+ def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
353
+ if scale is None:
354
+ scales = [None] # type: ignore
355
+ else:
356
+ scales = scale # type: ignore
357
+ res = self.model.decode(codes[None], scales)
358
+ return res[0]
359
+
360
+ def decode_latent(self, codes: torch.Tensor):
361
+ """Decode from the discrete codes to continuous latent space."""
362
+ return self.model.quantizer.decode(codes.transpose(0, 1))
363
+
364
+ @property
365
+ def channels(self) -> int:
366
+ return self.model.config.audio_channels
367
+
368
+ @property
369
+ def frame_rate(self) -> float:
370
+ hop_length = int(np.prod(self.model.config.upsampling_ratios))
371
+ return self.sample_rate / hop_length
372
+
373
+ @property
374
+ def sample_rate(self) -> int:
375
+ return self.model.config.sampling_rate
376
+
377
+ @property
378
+ def cardinality(self) -> int:
379
+ return self.model.config.codebook_size
380
+
381
+ @property
382
+ def num_codebooks(self) -> int:
383
+ return self._num_codebooks
384
+
385
+ @property
386
+ def total_codebooks(self) -> int:
387
+ return max(self.possible_num_codebooks)
388
+
389
+ def set_num_codebooks(self, n: int):
390
+ """Set the active number of codebooks used by the quantizer.
391
+ """
392
+ if n not in self.possible_num_codebooks:
393
+ raise ValueError(f"Allowed values for num codebooks: {self.possible_num_codebooks}")
394
+ self._num_codebooks = n
395
+
396
+
397
+ class InterleaveStereoCompressionModel(CompressionModel):
398
+ """Wraps a CompressionModel to support stereo inputs. The wrapped model
399
+ will be applied independently to the left and right channels, and both codebooks
400
+ will be interleaved. If the wrapped model returns a representation `[B, K ,T]` per
401
+ channel, then the output will be `[B, K * 2, T]` or `[B, K, T * 2]` depending on
402
+ `per_timestep`.
403
+
404
+ Args:
405
+ model (CompressionModel): Compression model to wrap.
406
+ per_timestep (bool): Whether to interleave on the timestep dimension
407
+ or on the codebooks dimension.
408
+ """
409
+ def __init__(self, model: CompressionModel, per_timestep: bool = False):
410
+ super().__init__()
411
+ self.model = model
412
+ self.per_timestep = per_timestep
413
+ assert self.model.channels == 1, "Wrapped model is expected to be for monophonic audio"
414
+
415
+ @property
416
+ def total_codebooks(self):
417
+ return self.model.total_codebooks
418
+
419
+ @property
420
+ def num_codebooks(self):
421
+ """Active number of codebooks used by the quantizer.
422
+
423
+ ..Warning:: this reports the number of codebooks after the interleaving
424
+ of the codebooks!
425
+ """
426
+ return self.model.num_codebooks if self.per_timestep else self.model.num_codebooks * 2
427
+
428
+ def set_num_codebooks(self, n: int):
429
+ """Set the active number of codebooks used by the quantizer.
430
+
431
+ ..Warning:: this sets the number of codebooks before the interleaving!
432
+ """
433
+ self.model.set_num_codebooks(n)
434
+
435
+ @property
436
+ def num_virtual_steps(self) -> float:
437
+ """Return the number of virtual steps, e.g. one real step
438
+ will be split into that many steps.
439
+ """
440
+ return 2 if self.per_timestep else 1
441
+
442
+ @property
443
+ def frame_rate(self) -> float:
444
+ return self.model.frame_rate * self.num_virtual_steps
445
+
446
+ @property
447
+ def sample_rate(self) -> int:
448
+ return self.model.sample_rate
449
+
450
+ @property
451
+ def channels(self) -> int:
452
+ return 2
453
+
454
+ @property
455
+ def cardinality(self):
456
+ """Cardinality of each codebook.
457
+ """
458
+ return self.model.cardinality
459
+
460
+ def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
461
+ raise NotImplementedError("Not supported, use encode and decode.")
462
+
463
+ def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
464
+ B, C, T = x.shape
465
+ assert C == self.channels, f"Expecting stereo audio but audio num channels is {C}"
466
+
467
+ indices_c0, scales_c0 = self.model.encode(x[:, 0, ...].unsqueeze(1))
468
+ indices_c1, scales_c1 = self.model.encode(x[:, 1, ...].unsqueeze(1))
469
+ indices = torch.stack([indices_c0, indices_c1], dim=0)
470
+ scales: tp.Optional[torch.Tensor] = None
471
+ if scales_c0 is not None and scales_c1 is not None:
472
+ scales = torch.stack([scales_c0, scales_c1], dim=1)
473
+
474
+ if self.per_timestep:
475
+ indices = rearrange(indices, 'c b k t -> b k (t c)', c=2)
476
+ else:
477
+ indices = rearrange(indices, 'c b k t -> b (k c) t', c=2)
478
+
479
+ return (indices, scales)
480
+
481
+ def get_left_right_codes(self, codes: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
482
+ if self.per_timestep:
483
+ codes = rearrange(codes, 'b k (t c) -> c b k t', c=2)
484
+ else:
485
+ codes = rearrange(codes, 'b (k c) t -> c b k t', c=2)
486
+ return codes[0], codes[1]
487
+
488
+ def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
489
+ B, K, T = codes.shape
490
+ assert T % self.num_virtual_steps == 0, "Provided codes' number of timesteps does not match"
491
+ assert K == self.num_codebooks, "Provided codes' number of codebooks does not match"
492
+
493
+ scale_c0, scale_c1 = None, None
494
+ if scale is not None:
495
+ assert scale.size(0) == B and scale.size(1) == 2, f"Scale has unexpected shape: {scale.shape}"
496
+ scale_c0 = scale[0, ...]
497
+ scale_c1 = scale[1, ...]
498
+
499
+ codes_c0, codes_c1 = self.get_left_right_codes(codes)
500
+ audio_c0 = self.model.decode(codes_c0, scale_c0)
501
+ audio_c1 = self.model.decode(codes_c1, scale_c1)
502
+ return torch.cat([audio_c0, audio_c1], dim=1)
503
+
504
+ def decode_latent(self, codes: torch.Tensor):
505
+ """Decode from the discrete codes to continuous latent space."""
506
+ raise NotImplementedError("Not supported by interleaved stereo wrapped models.")
audiocraft/models/genmodel.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Base implementation for audio generative models. This base implementation
9
+ combines all the required components to run inference with pretrained audio
10
+ generative models. It can be easily inherited by downstream model classes to
11
+ provide easy access to the generation API.
12
+ """
13
+
14
+ from abc import ABC, abstractmethod
15
+ import typing as tp
16
+
17
+ import omegaconf
18
+ import torch
19
+
20
+ from .encodec import CompressionModel
21
+ from .lm import LMModel
22
+ from .builders import get_wrapped_compression_model
23
+ from ..data.audio_utils import convert_audio
24
+ from ..modules.conditioners import ConditioningAttributes
25
+ from ..utils.autocast import TorchAutocast
26
+
27
+
28
+ class BaseGenModel(ABC):
29
+ """Base generative model with convenient generation API.
30
+
31
+ Args:
32
+ name (str): name of the model.
33
+ compression_model (CompressionModel): Compression model
34
+ used to map audio to invertible discrete representations.
35
+ lm (LMModel): Language model over discrete representations.
36
+ max_duration (float, optional): maximum duration the model can produce,
37
+ otherwise, inferred from the training params.
38
+ """
39
+ def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
40
+ max_duration: tp.Optional[float] = None):
41
+ self.name = name
42
+ self.compression_model = compression_model
43
+ self.lm = lm
44
+ self.cfg: tp.Optional[omegaconf.DictConfig] = None
45
+ # Just to be safe, let's put everything in eval mode.
46
+ self.compression_model.eval()
47
+ self.lm.eval()
48
+
49
+ if hasattr(lm, 'cfg'):
50
+ cfg = lm.cfg
51
+ assert isinstance(cfg, omegaconf.DictConfig)
52
+ self.cfg = cfg
53
+
54
+ if self.cfg is not None:
55
+ self.compression_model = get_wrapped_compression_model(self.compression_model, self.cfg)
56
+
57
+ if max_duration is None:
58
+ if self.cfg is not None:
59
+ max_duration = lm.cfg.dataset.segment_duration # type: ignore
60
+ else:
61
+ raise ValueError("You must provide max_duration when building directly your GenModel")
62
+ assert max_duration is not None
63
+
64
+ self.max_duration: float = max_duration
65
+ self.duration = self.max_duration
66
+
67
+ # self.extend_stride is the length of audio extension when generating samples longer
68
+ # than self.max_duration. NOTE: the derived class must set self.extend_stride to a
69
+ # positive float value when generating with self.duration > self.max_duration.
70
+ self.extend_stride: tp.Optional[float] = None
71
+ self.device = next(iter(lm.parameters())).device
72
+ self.generation_params: dict = {}
73
+ self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
74
+ if self.device.type == 'cpu':
75
+ self.autocast = TorchAutocast(enabled=False)
76
+ else:
77
+ self.autocast = TorchAutocast(
78
+ enabled=True, device_type=self.device.type, dtype=torch.float16)
79
+
80
+ @property
81
+ def frame_rate(self) -> float:
82
+ """Roughly the number of AR steps per seconds."""
83
+ return self.compression_model.frame_rate
84
+
85
+ @property
86
+ def sample_rate(self) -> int:
87
+ """Sample rate of the generated audio."""
88
+ return self.compression_model.sample_rate
89
+
90
+ @property
91
+ def audio_channels(self) -> int:
92
+ """Audio channels of the generated audio."""
93
+ return self.compression_model.channels
94
+
95
+ def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
96
+ """Override the default progress callback."""
97
+ self._progress_callback = progress_callback
98
+
99
+ @abstractmethod
100
+ def set_generation_params(self, *args, **kwargs):
101
+ """Set the generation parameters."""
102
+ raise NotImplementedError("No base implementation for setting generation params.")
103
+
104
+ @staticmethod
105
+ @abstractmethod
106
+ def get_pretrained(name: str, device=None):
107
+ raise NotImplementedError("No base implementation for getting pretrained model")
108
+
109
+ @torch.no_grad()
110
+ def _prepare_tokens_and_attributes(
111
+ self,
112
+ descriptions: tp.Sequence[tp.Optional[str]],
113
+ prompt: tp.Optional[torch.Tensor],
114
+ ) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]:
115
+ """Prepare model inputs.
116
+
117
+ Args:
118
+ descriptions (list of str): A list of strings used as text conditioning.
119
+ prompt (torch.Tensor): A batch of waveforms used for continuation.
120
+ """
121
+ attributes = [
122
+ ConditioningAttributes(text={'description': description})
123
+ for description in descriptions]
124
+
125
+ if prompt is not None:
126
+ if descriptions is not None:
127
+ assert len(descriptions) == len(prompt), "Prompt and nb. descriptions doesn't match"
128
+ prompt = prompt.to(self.device)
129
+ prompt_tokens, scale = self.compression_model.encode(prompt)
130
+ assert scale is None
131
+ else:
132
+ prompt_tokens = None
133
+ return attributes, prompt_tokens
134
+
135
+ def generate_unconditional(self, num_samples: int, progress: bool = False,
136
+ return_tokens: bool = False) -> tp.Union[torch.Tensor,
137
+ tp.Tuple[torch.Tensor, torch.Tensor]]:
138
+ """Generate samples in an unconditional manner.
139
+
140
+ Args:
141
+ num_samples (int): Number of samples to be generated.
142
+ progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
143
+ """
144
+ descriptions: tp.List[tp.Optional[str]] = [None] * num_samples
145
+ attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
146
+ tokens = self._generate_tokens(attributes, prompt_tokens, progress)
147
+ if return_tokens:
148
+ return self.generate_audio(tokens), tokens
149
+ return self.generate_audio(tokens)
150
+
151
+ def generate(self, descriptions: tp.List[str], progress: bool = False, return_tokens: bool = False) \
152
+ -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
153
+ """Generate samples conditioned on text.
154
+
155
+ Args:
156
+ descriptions (list of str): A list of strings used as text conditioning.
157
+ progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
158
+ """
159
+ attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
160
+ assert prompt_tokens is None
161
+ tokens = self._generate_tokens(attributes, prompt_tokens, progress)
162
+ if return_tokens:
163
+ return self.generate_audio(tokens), tokens
164
+ return self.generate_audio(tokens)
165
+
166
+ def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
167
+ descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None,
168
+ progress: bool = False, return_tokens: bool = False) \
169
+ -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
170
+ """Generate samples conditioned on audio prompts and an optional text description.
171
+
172
+ Args:
173
+ prompt (torch.Tensor): A batch of waveforms used for continuation.
174
+ Prompt should be [B, C, T], or [C, T] if only one sample is generated.
175
+ prompt_sample_rate (int): Sampling rate of the given audio waveforms.
176
+ descriptions (list of str, optional): A list of strings used as text conditioning. Defaults to None.
177
+ progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
178
+ """
179
+ if prompt.dim() == 2:
180
+ prompt = prompt[None]
181
+ if prompt.dim() != 3:
182
+ raise ValueError("prompt should have 3 dimensions: [B, C, T] (C = 1).")
183
+ prompt = convert_audio(prompt, prompt_sample_rate, self.sample_rate, self.audio_channels)
184
+ if descriptions is None:
185
+ descriptions = [None] * len(prompt)
186
+ attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt)
187
+ assert prompt_tokens is not None
188
+ tokens = self._generate_tokens(attributes, prompt_tokens, progress)
189
+ if return_tokens:
190
+ return self.generate_audio(tokens), tokens
191
+ return self.generate_audio(tokens)
192
+
193
+ def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
194
+ prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor:
195
+ """Generate discrete audio tokens given audio prompt and/or conditions.
196
+
197
+ Args:
198
+ attributes (list of ConditioningAttributes): Conditions used for generation (here text).
199
+ prompt_tokens (torch.Tensor, optional): Audio prompt used for continuation.
200
+ progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
201
+ Returns:
202
+ torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
203
+ """
204
+ total_gen_len = int(self.duration * self.frame_rate)
205
+ max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate)
206
+ current_gen_offset: int = 0
207
+
208
+ def _progress_callback(generated_tokens: int, tokens_to_generate: int):
209
+ generated_tokens += current_gen_offset
210
+ if self._progress_callback is not None:
211
+ # Note that total_gen_len might be quite wrong depending on the
212
+ # codebook pattern used, but with delay it is almost accurate.
213
+ self._progress_callback(generated_tokens, tokens_to_generate)
214
+ else:
215
+ print(f'{generated_tokens: 6d} / {tokens_to_generate: 6d}', end='\r')
216
+
217
+ if prompt_tokens is not None:
218
+ assert max_prompt_len >= prompt_tokens.shape[-1], \
219
+ "Prompt is longer than audio to generate"
220
+
221
+ callback = None
222
+ if progress:
223
+ callback = _progress_callback
224
+
225
+ if self.duration <= self.max_duration:
226
+ # generate by sampling from LM, simple case.
227
+ with self.autocast:
228
+ gen_tokens = self.lm.generate(
229
+ prompt_tokens, attributes,
230
+ callback=callback, max_gen_len=total_gen_len, **self.generation_params)
231
+
232
+ else:
233
+ assert self.extend_stride is not None, "Stride should be defined to generate beyond max_duration"
234
+ assert self.extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
235
+ all_tokens = []
236
+ if prompt_tokens is None:
237
+ prompt_length = 0
238
+ else:
239
+ all_tokens.append(prompt_tokens)
240
+ prompt_length = prompt_tokens.shape[-1]
241
+
242
+ stride_tokens = int(self.frame_rate * self.extend_stride)
243
+ while current_gen_offset + prompt_length < total_gen_len:
244
+ time_offset = current_gen_offset / self.frame_rate
245
+ chunk_duration = min(self.duration - time_offset, self.max_duration)
246
+ max_gen_len = int(chunk_duration * self.frame_rate)
247
+ with self.autocast:
248
+ gen_tokens = self.lm.generate(
249
+ prompt_tokens, attributes,
250
+ callback=callback, max_gen_len=max_gen_len, **self.generation_params)
251
+ if prompt_tokens is None:
252
+ all_tokens.append(gen_tokens)
253
+ else:
254
+ all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:])
255
+ prompt_tokens = gen_tokens[:, :, stride_tokens:]
256
+ prompt_length = prompt_tokens.shape[-1]
257
+ current_gen_offset += stride_tokens
258
+
259
+ gen_tokens = torch.cat(all_tokens, dim=-1)
260
+ return gen_tokens
261
+
262
+ def generate_audio(self, gen_tokens: torch.Tensor) -> torch.Tensor:
263
+ """Generate Audio from tokens."""
264
+ assert gen_tokens.dim() == 3
265
+ with torch.no_grad():
266
+ gen_audio = self.compression_model.decode(gen_tokens, None)
267
+ return gen_audio
audiocraft/models/lm.py ADDED
@@ -0,0 +1,547 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from dataclasses import dataclass
8
+ from functools import partial
9
+ import logging
10
+ import math
11
+ import typing as tp
12
+
13
+ import torch
14
+ from torch import nn
15
+
16
+ from ..utils import utils
17
+ from ..modules.streaming import StreamingModule, State
18
+ from ..modules.transformer import StreamingTransformer, create_norm_fn
19
+ from ..modules.conditioners import (
20
+ ConditionFuser,
21
+ ClassifierFreeGuidanceDropout,
22
+ AttributeDropout,
23
+ ConditioningProvider,
24
+ ConditioningAttributes,
25
+ ConditionType,
26
+ )
27
+ from ..modules.codebooks_patterns import CodebooksPatternProvider
28
+ from ..modules.activations import get_activation_fn
29
+
30
+
31
+ logger = logging.getLogger(__name__)
32
+ ConditionTensors = tp.Dict[str, ConditionType]
33
+ CFGConditions = tp.Union[ConditionTensors, tp.Tuple[ConditionTensors, ConditionTensors]]
34
+
35
+
36
+ def get_init_fn(method: str, input_dim: int, init_depth: tp.Optional[int] = None):
37
+ """LM layer initialization.
38
+ Inspired from xlformers: https://github.com/fairinternal/xlformers
39
+
40
+ Args:
41
+ method (str): Method name for init function. Valid options are:
42
+ 'gaussian', 'uniform'.
43
+ input_dim (int): Input dimension of the initialized module.
44
+ init_depth (int, optional): Optional init depth value used to rescale
45
+ the standard deviation if defined.
46
+ """
47
+ # Compute std
48
+ std = 1 / math.sqrt(input_dim)
49
+ # Rescale with depth
50
+ if init_depth is not None:
51
+ std = std / math.sqrt(2 * init_depth)
52
+
53
+ if method == 'gaussian':
54
+ return partial(
55
+ torch.nn.init.trunc_normal_, mean=0.0, std=std, a=-3 * std, b=3 * std
56
+ )
57
+ elif method == 'uniform':
58
+ bound = math.sqrt(3) * std # ensure the standard deviation is `std`
59
+ return partial(torch.nn.init.uniform_, a=-bound, b=bound)
60
+ else:
61
+ raise ValueError("Unsupported layer initialization method")
62
+
63
+
64
+ def init_layer(m: nn.Module,
65
+ method: str,
66
+ init_depth: tp.Optional[int] = None,
67
+ zero_bias_init: bool = False):
68
+ """Wrapper around ``get_init_fn`` for proper initialization of LM modules.
69
+
70
+ Args:
71
+ m (nn.Module): Module to initialize.
72
+ method (str): Method name for the init function.
73
+ init_depth (int, optional): Optional init depth value used to rescale
74
+ the standard deviation if defined.
75
+ zero_bias_init (bool): Whether to initialize the bias to 0 or not.
76
+ """
77
+ if isinstance(m, nn.Linear):
78
+ init_fn = get_init_fn(method, m.in_features, init_depth=init_depth)
79
+ if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16:
80
+ weight = m.weight.float()
81
+ init_fn(weight)
82
+ m.weight.data[:] = weight.half()
83
+ else:
84
+ init_fn(m.weight)
85
+ if zero_bias_init and m.bias is not None:
86
+ nn.init.constant_(m.bias, 0)
87
+ elif isinstance(m, nn.Embedding):
88
+ init_fn = get_init_fn(method, m.embedding_dim, init_depth=None)
89
+ if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16:
90
+ weight = m.weight.float()
91
+ init_fn(weight)
92
+ m.weight.data[:] = weight.half()
93
+ else:
94
+ init_fn(m.weight)
95
+
96
+
97
+ class ScaledEmbedding(nn.Embedding):
98
+ """Boost learning rate for embeddings (with `scale`).
99
+ """
100
+ def __init__(self, *args, lr=None, **kwargs):
101
+ super().__init__(*args, **kwargs)
102
+ self.lr = lr
103
+
104
+ def make_optim_group(self):
105
+ group = {"params": list(self.parameters())}
106
+ if self.lr is not None:
107
+ group["lr"] = self.lr
108
+ return group
109
+
110
+
111
+ @dataclass
112
+ class LMOutput:
113
+ # The logits are already re-aligned with the input codes
114
+ # hence no extra shift is required, e.g. when computing CE
115
+ logits: torch.Tensor # [B, K, T, card]
116
+ mask: torch.Tensor # [B, K, T]
117
+
118
+
119
+ class LMModel(StreamingModule):
120
+ """Transformer-based language model on multiple streams of codes.
121
+
122
+ Args:
123
+ pattern_provider (CodebooksPatternProvider): Pattern provider for codebook interleaving.
124
+ condition_provider (MusicConditioningProvider): Conditioning provider from metadata.
125
+ fuser (ConditionFuser): Fuser handling the fusing of conditions with language model input.
126
+ n_q (int): Number of parallel streams to model.
127
+ card (int): Cardinality, vocabulary size.
128
+ dim (int): Dimension of the transformer encoder.
129
+ num_heads (int): Number of heads for the transformer encoder.
130
+ hidden_scale (int): Scale for hidden feed forward dimension of the transformer encoder.
131
+ norm (str): Normalization method.
132
+ norm_first (bool): Use pre-norm instead of post-norm.
133
+ emb_lr (float, optional): Embedding-specific learning rate.
134
+ bias_proj (bool): Use bias for output projections.
135
+ weight_init (str, optional): Method for weight initialization.
136
+ depthwise_init (str, optional): Method for depthwise weight initialization.
137
+ zero_bias_init (bool): If true and bias in Linears, initialize bias to zeros.
138
+ cfg_dropout (float): Classifier-free guidance dropout.
139
+ cfg_coef (float): Classifier-free guidance coefficient.
140
+ attribute_dropout (dict): Attribute dropout probabilities.
141
+ two_step_cfg (bool): Whether to run classifier free-guidance with 2 distinct steps.
142
+ **kwargs: Additional parameters for the transformer encoder.
143
+ """
144
+ def __init__(self, pattern_provider: CodebooksPatternProvider, condition_provider: ConditioningProvider,
145
+ fuser: ConditionFuser, n_q: int = 8, card: int = 1024, dim: int = 128, num_heads: int = 8,
146
+ hidden_scale: int = 4, norm: str = 'layer_norm', norm_first: bool = False,
147
+ emb_lr: tp.Optional[float] = None, bias_proj: bool = True,
148
+ weight_init: tp.Optional[str] = None, depthwise_init: tp.Optional[str] = None,
149
+ zero_bias_init: bool = False, cfg_dropout: float = 0, cfg_coef: float = 1.0,
150
+ attribute_dropout: tp.Dict[str, tp.Dict[str, float]] = {}, two_step_cfg: bool = False,
151
+ **kwargs):
152
+ super().__init__()
153
+ self.cfg_coef = cfg_coef
154
+ self.cfg_dropout = ClassifierFreeGuidanceDropout(p=cfg_dropout)
155
+ self.att_dropout = AttributeDropout(p=attribute_dropout)
156
+ self.condition_provider = condition_provider
157
+ self.fuser = fuser
158
+ self.card = card
159
+ embed_dim = self.card + 1
160
+ self.n_q = n_q
161
+ self.dim = dim
162
+ self.pattern_provider = pattern_provider
163
+ self.two_step_cfg = two_step_cfg
164
+ self.emb = nn.ModuleList([ScaledEmbedding(embed_dim, dim, lr=emb_lr) for _ in range(n_q)])
165
+ if 'activation' in kwargs:
166
+ kwargs['activation'] = get_activation_fn(kwargs['activation'])
167
+ self.transformer = StreamingTransformer(
168
+ d_model=dim, num_heads=num_heads, dim_feedforward=int(hidden_scale * dim),
169
+ norm=norm, norm_first=norm_first, **kwargs)
170
+ self.out_norm: tp.Optional[nn.Module] = None
171
+ if norm_first:
172
+ self.out_norm = create_norm_fn(norm, dim)
173
+ self.linears = nn.ModuleList([nn.Linear(dim, self.card, bias=bias_proj) for _ in range(n_q)])
174
+ self._init_weights(weight_init, depthwise_init, zero_bias_init)
175
+ self._fsdp: tp.Optional[nn.Module]
176
+ self.__dict__['_fsdp'] = None
177
+
178
+ def _init_weights(self, weight_init: tp.Optional[str], depthwise_init: tp.Optional[str], zero_bias_init: bool):
179
+ """Initialization of the transformer module weights.
180
+
181
+ Args:
182
+ weight_init (str, optional): Weight initialization strategy. See ``get_init_fn`` for valid options.
183
+ depthwise_init (str, optional): Depthwise initialization strategy. The following options are valid:
184
+ 'current' where the depth corresponds to the current layer index or 'global' where the total number
185
+ of layer is used as depth. If not set, no depthwise initialization strategy is used.
186
+ zero_bias_init (bool): Whether to initialize bias to zero or not.
187
+ """
188
+ assert depthwise_init is None or depthwise_init in ['current', 'global']
189
+ assert depthwise_init is None or weight_init is not None, \
190
+ "If 'depthwise_init' is defined, a 'weight_init' method should be provided."
191
+ assert not zero_bias_init or weight_init is not None, \
192
+ "If 'zero_bias_init', a 'weight_init' method should be provided"
193
+
194
+ if weight_init is None:
195
+ return
196
+
197
+ for emb_layer in self.emb:
198
+ init_layer(emb_layer, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)
199
+
200
+ for layer_idx, tr_layer in enumerate(self.transformer.layers):
201
+ depth = None
202
+ if depthwise_init == 'current':
203
+ depth = layer_idx + 1
204
+ elif depthwise_init == 'global':
205
+ depth = len(self.transformer.layers)
206
+ init_fn = partial(init_layer, method=weight_init, init_depth=depth, zero_bias_init=zero_bias_init)
207
+ tr_layer.apply(init_fn)
208
+
209
+ for linear in self.linears:
210
+ init_layer(linear, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)
211
+
212
+ @property
213
+ def special_token_id(self) -> int:
214
+ return self.card
215
+
216
+ @property
217
+ def num_codebooks(self) -> int:
218
+ return self.n_q
219
+
220
+ def forward(self, sequence: torch.Tensor,
221
+ conditions: tp.List[ConditioningAttributes],
222
+ condition_tensors: tp.Optional[ConditionTensors] = None,
223
+ stage: int = -1) -> torch.Tensor:
224
+ """Apply language model on sequence and conditions.
225
+ Given a tensor of sequence of shape [B, K, S] with K the number of codebooks and
226
+ S the sequence steps, return the logits with shape [B, card, K, S].
227
+
228
+ Args:
229
+ indices (torch.Tensor): Indices of the codes to model.
230
+ conditions (list of ConditioningAttributes): Conditions to use when modeling
231
+ the given codes. Note that when evaluating multiple time with the same conditioning
232
+ you should pre-compute those and pass them as `condition_tensors`.
233
+ condition_tensors (dict[str, ConditionType], optional): Pre-computed conditioning
234
+ tensors, see `conditions`.
235
+ stage (int): The codebook level that is being predicted. Relevant for MAGNeT
236
+ in which prediction is done in a codebook-by-codebook manner.
237
+ Takes values in range(n_q), and ignored by default.
238
+ Returns:
239
+ torch.Tensor: Logits.
240
+ """
241
+ B, K, S = sequence.shape
242
+ assert K == self.num_codebooks, "Sequence shape must match the specified number of codebooks"
243
+ input_ = sum([self.emb[k](sequence[:, k]) for k in range(K)])
244
+ if condition_tensors is None:
245
+ assert not self._is_streaming, "Conditions tensors should be precomputed when streaming."
246
+ # apply dropout modules
247
+ conditions = self.cfg_dropout(conditions)
248
+ conditions = self.att_dropout(conditions)
249
+ tokenized = self.condition_provider.tokenize(conditions)
250
+ # encode conditions and fuse, both have a streaming cache to not recompute when generating.
251
+ condition_tensors = self.condition_provider(tokenized)
252
+ else:
253
+ assert not conditions, "Shouldn't pass both conditions and condition_tensors."
254
+
255
+ input_, cross_attention_input = self.fuser(input_, condition_tensors)
256
+
257
+ out = self.transformer(input_, cross_attention_src=cross_attention_input,
258
+ src_mask=(self.attn_mask_per_stage[stage] if stage >= 0 else None))
259
+ if self.out_norm:
260
+ out = self.out_norm(out)
261
+ logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1) # [B, K, S, card]
262
+
263
+ # remove the prefix from the model outputs
264
+ if len(self.fuser.fuse2cond['prepend']) > 0:
265
+ logits = logits[:, :, -S:]
266
+
267
+ return logits # [B, K, S, card]
268
+
269
+ def compute_predictions(
270
+ self, codes: torch.Tensor,
271
+ conditions: tp.List[ConditioningAttributes],
272
+ condition_tensors: tp.Optional[ConditionTensors] = None,
273
+ stage: int = -1,
274
+ keep_only_valid_steps: bool = True) -> LMOutput:
275
+ """Given an input tensor of codes [B, K, T] and list of conditions, runs the model
276
+ forward using the specified codes interleaving pattern.
277
+
278
+ Args:
279
+ codes (torch.Tensor): Input codes of shape [B, K, T] with B the batch size,
280
+ K the number of codebooks and T the number of timesteps.
281
+ conditions (list of ConditioningAttributes): conditionings to use when modeling
282
+ the given codes. Note that when evaluating multiple time with the same conditioning
283
+ you should pre-compute those and pass them as `condition_tensors`.
284
+ condition_tensors (dict[str, ConditionType], optional): pre-computed conditioning
285
+ tensors, see `conditions`.
286
+ stage (int): The codebook level that is being predicted. Relevant for MAGNeT
287
+ in which prediction is done in a codebook-by-codebook manner.
288
+ Takes values in range(n_q), and ignored by default.
289
+ keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
290
+ Steps that are beyond valid steps will be replaced by the special_token in that case.
291
+ Returns:
292
+ LMOutput: Language model outputs
293
+ logits (torch.Tensor) of shape [B, K, T, card] corresponding to the provided codes,
294
+ i.e. the first item corresponds to logits to predict the first code, meaning that
295
+ no additional shifting of codes and logits is required.
296
+ mask (torch.Tensor) of shape [B, K, T], mask over valid and invalid positions.
297
+ Given the specified interleaving strategies, parts of the logits and codes should
298
+ not be considered as valid predictions because of invalid context.
299
+ """
300
+ B, K, T = codes.shape
301
+ codes = codes.contiguous()
302
+ # map codes [B, K, T] into pattern sequence [B, K, S] using special_token_id for masked tokens
303
+ pattern = self.pattern_provider.get_pattern(T)
304
+ sequence_codes, sequence_indexes, sequence_mask = pattern.build_pattern_sequence(
305
+ codes, self.special_token_id, keep_only_valid_steps=keep_only_valid_steps,
306
+ )
307
+
308
+ # apply model on pattern sequence
309
+ model = self if self._fsdp is None else self._fsdp
310
+ logits = model(sequence_codes, conditions, condition_tensors, stage=stage) # [B, K, S, card]
311
+ # map back the logits on pattern sequence to logits on original codes: [B, K, S, card] -> [B, K, T, card]
312
+ # and provide the corresponding mask over invalid positions of tokens
313
+ logits = logits.permute(0, 3, 1, 2) # [B, card, K, S]
314
+ # note: we use nans as special token to make it obvious if we feed unexpected logits
315
+ logits, logits_indexes, logits_mask = pattern.revert_pattern_logits(
316
+ logits, float('nan'), keep_only_valid_steps=keep_only_valid_steps
317
+ )
318
+ logits = logits.permute(0, 2, 3, 1) # [B, K, T, card]
319
+ logits_mask = logits_mask[None, :, :].expand(B, -1, -1) # [K, T] -> [B, K, T]
320
+ return LMOutput(logits, logits_mask)
321
+
322
+ def _sample_next_token(self,
323
+ sequence: torch.Tensor,
324
+ cfg_conditions: CFGConditions,
325
+ unconditional_state: State,
326
+ use_sampling: bool = False,
327
+ temp: float = 1.0,
328
+ top_k: int = 0,
329
+ top_p: float = 0.0,
330
+ cfg_coef: tp.Optional[float] = None,
331
+ two_step_cfg: tp.Optional[bool] = None) -> torch.Tensor:
332
+ """Sample next token from the model given a sequence and a set of conditions. The model supports
333
+ multiple sampling strategies (greedy sampling, softmax, top-k, top-p...).
334
+
335
+ Args:
336
+ sequence (torch.Tensor): Current sequence of shape [B, K, S]
337
+ with K corresponding to the number of codebooks and S the number of sequence steps.
338
+ S = 1 in streaming mode, except for the first step that contains a bigger prompt.
339
+ condition_tensors (dict[str, ConditionType): Set of conditions. If CFG is used,
340
+ should be twice the batch size, being the concatenation of the conditions + null conditions.
341
+ use_sampling (bool): Whether to use a sampling strategy or not.
342
+ temp (float): Sampling temperature.
343
+ top_k (int): K for "top-k" sampling.
344
+ top_p (float): P for "top-p" sampling.
345
+ cfg_coef (float, optional): classifier free guidance coefficient
346
+ Returns:
347
+ next_token (torch.Tensor): Next token tensor of shape [B, K, 1].
348
+ """
349
+ B = sequence.shape[0]
350
+ cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef
351
+ model = self if self._fsdp is None else self._fsdp
352
+ two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
353
+ if two_step_cfg and cfg_conditions != {}:
354
+ assert isinstance(cfg_conditions, tuple), type(cfg_conditions)
355
+ condition_tensors, null_condition_tensors = cfg_conditions
356
+ cond_logits = model(sequence, conditions=[], condition_tensors=condition_tensors)
357
+ state = self.get_streaming_state()
358
+ self.set_streaming_state(unconditional_state)
359
+ uncond_logits = model(sequence, conditions=[], condition_tensors=null_condition_tensors)
360
+ unconditional_state.update(self.get_streaming_state())
361
+ self.set_streaming_state(state)
362
+ logits = uncond_logits + (cond_logits - uncond_logits) * self.cfg_coef
363
+ else:
364
+ assert isinstance(cfg_conditions, dict)
365
+ condition_tensors = cfg_conditions
366
+ if condition_tensors:
367
+ # Preparing for CFG, predicting both conditional and unconditional logits.
368
+ sequence = torch.cat([sequence, sequence], dim=0)
369
+ all_logits = model(
370
+ sequence,
371
+ conditions=[], condition_tensors=condition_tensors)
372
+ if condition_tensors:
373
+ cond_logits, uncond_logits = all_logits.split(B, dim=0) # [B, K, T, card]
374
+ logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef
375
+ else:
376
+ logits = all_logits
377
+
378
+ logits = logits.permute(0, 1, 3, 2) # [B, K, card, T]
379
+ logits = logits[..., -1] # [B x K x card]
380
+
381
+ # Apply softmax for sampling if temp > 0. Else, do greedy sampling to avoid zero division error.
382
+ if use_sampling and temp > 0.0:
383
+ probs = torch.softmax(logits / temp, dim=-1)
384
+ if top_p > 0.0:
385
+ next_token = utils.sample_top_p(probs, p=top_p)
386
+ elif top_k > 0:
387
+ next_token = utils.sample_top_k(probs, k=top_k)
388
+ else:
389
+ next_token = utils.multinomial(probs, num_samples=1)
390
+ else:
391
+ next_token = torch.argmax(logits, dim=-1, keepdim=True)
392
+
393
+ return next_token
394
+
395
+ @torch.no_grad()
396
+ def generate(self,
397
+ prompt: tp.Optional[torch.Tensor] = None,
398
+ conditions: tp.List[ConditioningAttributes] = [],
399
+ num_samples: tp.Optional[int] = None,
400
+ max_gen_len: int = 256,
401
+ use_sampling: bool = True,
402
+ temp: float = 1.0,
403
+ top_k: int = 250,
404
+ top_p: float = 0.0,
405
+ cfg_coef: tp.Optional[float] = None,
406
+ two_step_cfg: tp.Optional[bool] = None,
407
+ remove_prompts: bool = False,
408
+ check: bool = False,
409
+ callback: tp.Optional[tp.Callable[[int, int], None]] = None,
410
+ **kwargs) -> torch.Tensor:
411
+ """Generate tokens sampling from the model given a prompt or unconditionally. Generation can
412
+ be performed in a greedy fashion or using sampling with top K and top P strategies.
413
+
414
+ Args:
415
+ prompt (torch.Tensor, optional): Prompt tokens of shape [B, K, T].
416
+ conditions_tensors (list of ConditioningAttributes, optional): List of conditions.
417
+ num_samples (int, optional): Number of samples to generate when no prompt and no conditions are given.
418
+ max_gen_len (int): Maximum generation length.
419
+ use_sampling (bool): Whether to use a sampling strategy or not.
420
+ temp (float): Sampling temperature.
421
+ top_k (int): K for "top-k" sampling.
422
+ top_p (float): P for "top-p" sampling.
423
+ cfg_coeff (float, optional): Classifier-free guidance coefficient.
424
+ two_step_cfg (bool, optional): Whether to perform classifier-free guidance with two steps generation.
425
+ remove_prompts (bool): Whether to remove prompts from generation or not.
426
+ check (bool): Whether to apply further checks on generated sequence.
427
+ callback (Callback, optional): Callback function to report generation progress.
428
+ Returns:
429
+ torch.Tensor: Generated tokens.
430
+ """
431
+ assert not self.training, "generation shouldn't be used in training mode."
432
+ first_param = next(iter(self.parameters()))
433
+ device = first_param.device
434
+
435
+ # Checking all input shapes are consistent.
436
+ possible_num_samples = []
437
+ if num_samples is not None:
438
+ possible_num_samples.append(num_samples)
439
+ elif prompt is not None:
440
+ possible_num_samples.append(prompt.shape[0])
441
+ elif conditions:
442
+ possible_num_samples.append(len(conditions))
443
+ else:
444
+ possible_num_samples.append(1)
445
+ assert [x == possible_num_samples[0] for x in possible_num_samples], "Inconsistent inputs shapes"
446
+ num_samples = possible_num_samples[0]
447
+
448
+ # below we create set of conditions: one conditional and one unconditional
449
+ # to do that we merge the regular condition together with the null condition
450
+ # we then do 1 forward pass instead of 2.
451
+ # the reason for that is two-fold:
452
+ # 1. it is about x2 faster than doing 2 forward passes
453
+ # 2. avoid the streaming API treating the 2 passes as part of different time steps
454
+ # We also support doing two different passes, in particular to ensure that
455
+ # the padding structure is exactly the same between train and test.
456
+ # With a batch size of 1, this can be slower though.
457
+ cfg_conditions: CFGConditions
458
+ two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
459
+ if conditions:
460
+ null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
461
+ if two_step_cfg:
462
+ cfg_conditions = (
463
+ self.condition_provider(self.condition_provider.tokenize(conditions)),
464
+ self.condition_provider(self.condition_provider.tokenize(null_conditions)),
465
+ )
466
+ else:
467
+ conditions = conditions + null_conditions
468
+ tokenized = self.condition_provider.tokenize(conditions)
469
+ cfg_conditions = self.condition_provider(tokenized)
470
+ else:
471
+ cfg_conditions = {}
472
+
473
+ if prompt is None:
474
+ assert num_samples > 0
475
+ prompt = torch.zeros((num_samples, self.num_codebooks, 0), dtype=torch.long, device=device)
476
+
477
+ B, K, T = prompt.shape
478
+ start_offset = T
479
+ assert start_offset < max_gen_len
480
+
481
+ pattern = self.pattern_provider.get_pattern(max_gen_len)
482
+ # this token is used as default value for codes that are not generated yet
483
+ unknown_token = -1
484
+
485
+ # we generate codes up to the max_gen_len that will be mapped to the pattern sequence
486
+ gen_codes = torch.full((B, K, max_gen_len), unknown_token, dtype=torch.long, device=device)
487
+ # filling the gen_codes with the prompt if needed
488
+ gen_codes[..., :start_offset] = prompt
489
+ # create the gen_sequence with proper interleaving from the pattern: [B, K, S]
490
+ gen_sequence, indexes, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id)
491
+ # retrieve the start_offset in the sequence:
492
+ # it is the first sequence step that contains the `start_offset` timestep
493
+ start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
494
+ assert start_offset_sequence is not None
495
+
496
+ with self.streaming():
497
+ unconditional_state = self.get_streaming_state()
498
+ prev_offset = 0
499
+ gen_sequence_len = gen_sequence.shape[-1] # gen_sequence shape is [B, K, S]
500
+ for offset in range(start_offset_sequence, gen_sequence_len):
501
+ # get current sequence (note that the streaming API is providing the caching over previous offsets)
502
+ curr_sequence = gen_sequence[..., prev_offset:offset]
503
+ curr_mask = mask[None, ..., prev_offset:offset].expand(B, -1, -1)
504
+ if check:
505
+ # check coherence between mask and sequence
506
+ assert (curr_sequence == torch.where(curr_mask, curr_sequence, self.special_token_id)).all()
507
+ # should never happen as gen_sequence is filled progressively
508
+ assert not (curr_sequence == unknown_token).any()
509
+ # sample next token from the model, next token shape is [B, K, 1]
510
+ next_token = self._sample_next_token(
511
+ curr_sequence, cfg_conditions, unconditional_state, use_sampling, temp, top_k, top_p,
512
+ cfg_coef=cfg_coef, two_step_cfg=two_step_cfg)
513
+ # ensure the tokens that should be masked are properly set to special_token_id
514
+ # as the model never output special_token_id
515
+ valid_mask = mask[..., offset:offset+1].expand(B, -1, -1)
516
+ next_token[~valid_mask] = self.special_token_id
517
+ # ensure we don't overwrite prompt tokens, we only write over unknown tokens
518
+ # (then mask tokens should be left as is as well, which is correct)
519
+ gen_sequence[..., offset:offset+1] = torch.where(
520
+ gen_sequence[..., offset:offset+1] == unknown_token,
521
+ next_token, gen_sequence[..., offset:offset+1]
522
+ )
523
+ prev_offset = offset
524
+ if callback is not None:
525
+ callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence)
526
+ unconditional_state.clear()
527
+
528
+ # ensure sequence has been entirely filled
529
+ assert not (gen_sequence == unknown_token).any()
530
+ # ensure gen_sequence pattern and mask are matching
531
+ # which means the gen_sequence is valid according to the pattern
532
+ assert (
533
+ gen_sequence == torch.where(mask[None, ...].expand(B, -1, -1), gen_sequence, self.special_token_id)
534
+ ).all()
535
+ # get back the codes, trimming the prompt if needed and cutting potentially incomplete timesteps
536
+ out_codes, out_indexes, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token)
537
+
538
+ # sanity checks over the returned codes and corresponding masks
539
+ assert (out_codes[..., :max_gen_len] != unknown_token).all()
540
+ assert (out_mask[..., :max_gen_len] == 1).all()
541
+
542
+ out_start_offset = start_offset if remove_prompts else 0
543
+ out_codes = out_codes[..., out_start_offset:max_gen_len]
544
+
545
+ # ensure the returned codes are all valid
546
+ assert (out_codes >= 0).all() and (out_codes <= self.card).all()
547
+ return out_codes
audiocraft/models/lm_magnet.py ADDED
@@ -0,0 +1,498 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import logging
8
+ import math
9
+ import typing as tp
10
+ import torch
11
+ import numpy as np
12
+
13
+ from ..utils import utils
14
+ from ..modules.conditioners import (
15
+ ClassifierFreeGuidanceDropout,
16
+ ConditioningAttributes,
17
+ ConditionType,
18
+ )
19
+ from .lm import LMModel
20
+
21
+ logger = logging.getLogger(__name__)
22
+ ConditionTensors = tp.Dict[str, ConditionType]
23
+ CFGConditions = tp.Union[ConditionTensors, tp.Tuple[ConditionTensors, ConditionTensors]]
24
+
25
+
26
+ class MagnetLMModel(LMModel):
27
+ """Transformer-based, non-autoregressive model, operates on multiple streams of audio tokens (MAGNeT).
28
+ Args:
29
+ subcodes_context (int): The number of timesteps attended in the self-attention blocks of codebooks > 0.
30
+ When set to -1, attention is unrestricted and all timesteps are attended. Defaults to 5.
31
+ compression_model_framerate (int): frame rate of the audio tokenizer.
32
+ segment_duration (int): Sample length in seconds.
33
+ span_len (int): Determines the length of masking spans. This is the minimal length of consecutive masked tokens,
34
+ for both training and inference. Defaults to 3.
35
+ **kwargs: Additional parameters for the LMModel.
36
+ """
37
+ def __init__(self, subcodes_context: int = 5, compression_model_framerate: int = 50,
38
+ segment_duration: int = 10, span_len: int = 3, **kwargs):
39
+ super().__init__(**kwargs)
40
+ self.causal = kwargs['causal']
41
+ self.subcodes_context = subcodes_context
42
+ self.span_len = span_len
43
+ self._build_attn_masks(compression_model_framerate=compression_model_framerate,
44
+ segment_duration=segment_duration,
45
+ num_heads=kwargs['num_heads'],
46
+ device=kwargs['device'], dtype=kwargs['dtype'])
47
+
48
+ def restricted_context_attn_mask(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
49
+ """Creates a restricted attention mask (local attention map) where the context
50
+ is determined by self.subcodes_context.
51
+ Args:
52
+ seq_len (int): token sequence length.
53
+ device (torch.device): device of the output tensor.
54
+ dtype (torch.dtype): data type of the output tensor.
55
+ Returns:
56
+ torch.Tensor: The restricted attention mask.
57
+ """
58
+ # Return a context restricted non-causal att mask
59
+ queries_pos = torch.arange(seq_len, device=device).view(-1, 1)
60
+ keys_pos = torch.arange(seq_len, device=device).view(1, -1)
61
+
62
+ delta = queries_pos - keys_pos
63
+ valid = torch.abs(delta) <= self.subcodes_context
64
+ return torch.where(
65
+ valid,
66
+ torch.zeros([], device=device, dtype=dtype),
67
+ torch.full([], float('-inf'), device=device, dtype=dtype))
68
+
69
+ def _stage_attn_mask(self, stage: int, seq_len: int, num_heads: int,
70
+ device: torch.device, dtype: torch.dtype) -> tp.Optional[torch.Tensor]:
71
+ """Creates a restricted attention mask given the stage (codebook index).
72
+ Args:
73
+ stage (int): The codebook index. Takes values in [0, n_q].
74
+ seq_len (int): Token sequence length.
75
+ num_heads (int): Num transformer attention heads.
76
+ device (torch.device): device of the output tensor.
77
+ dtype (torch.dtype): data type of the output tensor.
78
+ Returns:
79
+ torch.Tensor: Either a restricted attention mask or None if stage attention is unrestricted.
80
+ """
81
+ sa_mask = None
82
+
83
+ if stage > 0 and self.subcodes_context > -1:
84
+ # parallel - non-causal - with restricted subcodes context
85
+ sa_mask = self.restricted_context_attn_mask(seq_len, device=device, dtype=dtype)
86
+
87
+ if sa_mask is not None:
88
+ # Repeat for each attention head
89
+ sa_mask = sa_mask.repeat((1, num_heads, 1, 1))
90
+
91
+ # align8 to enable memory efficient attention
92
+ MEMORY_EFFICIENT_ATTN_ALIGN_FACTOR = 8
93
+ seq_len_aligned = \
94
+ int(np.ceil(seq_len / MEMORY_EFFICIENT_ATTN_ALIGN_FACTOR)) * MEMORY_EFFICIENT_ATTN_ALIGN_FACTOR
95
+
96
+ sa_mask_aligned = torch.zeros((1, num_heads, seq_len_aligned, seq_len_aligned), device=device, dtype=dtype)
97
+ sa_mask_aligned[..., :seq_len, :seq_len] = sa_mask
98
+ sa_mask = sa_mask_aligned
99
+
100
+ return sa_mask
101
+
102
+ def _build_attn_masks(self, compression_model_framerate: int, segment_duration: int, num_heads: int,
103
+ device: torch.device, dtype: torch.dtype):
104
+ """Construct attention mask per stage. For each of the RVQ codebook levels in the [0, n_q] range,
105
+ either a local attention map or None would be stored as an entry in the self.attn_mask_per_stage list.
106
+ Args:
107
+ compression_model_framerate (int): The frame rate of the tokenizer.
108
+ segment_duration (int): Sample length in seconds.
109
+ num_heads (int): Num transformer attention heads.
110
+ device (torch.device): device of the output tensor.
111
+ dtype (torch.dtype): data type of the output tensor.
112
+ """
113
+ seq_len = compression_model_framerate * segment_duration
114
+ self.attn_mask_per_stage = [self._stage_attn_mask(stage, seq_len, num_heads,
115
+ device, dtype) for stage in range(self.n_q)]
116
+
117
+ @torch.no_grad()
118
+ def generate(self,
119
+ prompt: tp.Optional[torch.Tensor] = None,
120
+ conditions: tp.List[ConditioningAttributes] = [],
121
+ num_samples: tp.Optional[int] = None,
122
+ max_gen_len: int = 256,
123
+ use_sampling: bool = True,
124
+ temp: float = 1.0,
125
+ top_k: int = 250,
126
+ top_p: float = 0.0,
127
+ cfg_coef: tp.Optional[float] = None,
128
+ two_step_cfg: tp.Optional[bool] = None,
129
+ remove_prompts: bool = False,
130
+ check: bool = False,
131
+ callback: tp.Optional[tp.Callable[[int, int], None]] = None,
132
+ **kwargs) -> torch.Tensor:
133
+
134
+ assert cfg_coef is None, "Unsupported in MAGNeT. Use max_cfg_coef,min_cfg_coef instead."
135
+ assert two_step_cfg is None, "MAGNeT currently doesn't support two step classifier-free-guidance."
136
+ assert remove_prompts is False, "MAGNeT currently doesn't support the remove_prompts arg."
137
+ assert check is False, "MAGNeT currently doesn't support the check arg."
138
+ # Call the MAGNeT-specific generation method
139
+ return self._generate_magnet(prompt=prompt,
140
+ conditions=conditions,
141
+ num_samples=num_samples,
142
+ max_gen_len=max_gen_len,
143
+ use_sampling=use_sampling,
144
+ temp=temp,
145
+ top_k=top_k,
146
+ top_p=top_p,
147
+ callback=callback, **kwargs)
148
+
149
+ @torch.no_grad()
150
+ def _generate_magnet(self,
151
+ prompt: tp.Optional[torch.Tensor] = None,
152
+ conditions: tp.List[ConditioningAttributes] = [],
153
+ num_samples: tp.Optional[int] = None,
154
+ max_gen_len: int = 256,
155
+ use_sampling: bool = True,
156
+ temp: float = 3.0,
157
+ top_k: int = 0,
158
+ top_p: float = 0.9,
159
+ callback: tp.Optional[tp.Callable[[int, int], None]] = None,
160
+ max_cfg_coef: float = 10.0,
161
+ min_cfg_coef: float = 1.0,
162
+ decoding_steps: tp.List[int] = [20, 10, 10, 10],
163
+ anneal_temp: bool = True,
164
+ span_scoring='max',
165
+ span_arrangement='nonoverlap') -> torch.Tensor:
166
+ """Generate audio tokens given textual conditions, and optionally given audio prompts,
167
+ by running MAGNeT's iterative decoding algorithm for each of the n_q RVQ levels.
168
+ Args:
169
+ prompt (torch.Tensor): Prompt tokens of shape [B, K, T].
170
+ conditions (list of ConditioningAttributes): List of conditions.
171
+ num_samples (int): Number of samples to generate when no prompt and no conditions are given.
172
+ max_gen_len (int): Maximum generation length.
173
+ use_sampling (bool): Whether to use a sampling strategy or not.
174
+ temp (float): Initial sampling temperature.
175
+ top_k (int): k for "top-k" sampling.
176
+ top_p (float): p for "top-p" sampling.
177
+ callback (Callback): Callback function to report generation progress.
178
+ max_clsfg_coef (float): Initial coefficient used for classifier free guidance.
179
+ min_clsfg_coef (float): Final coefficient used for classifier free guidance.
180
+ decoding_steps (list of n_q ints): The number of iterative decoding steps,
181
+ for each of the n_q RVQ codebooks.
182
+ anneal_temp (bool): When set to True, softmax temperature will be linearly decayed to zero, at each stage.
183
+ span_scoring (str): Use the maximum probability of each span ('max')
184
+ or the product of probabilities ('prod').
185
+ span_arrangement (str): Use either non-overlapping spans ('nonoverlap') or overlapping spans ('stride1').
186
+ in the masking scheme.
187
+ Returns:
188
+ torch.Tensor: Generated tokens.
189
+ """
190
+ assert not self.training, "generation shouldn't be used in training mode."
191
+ first_param = next(iter(self.parameters()))
192
+ device = first_param.device
193
+
194
+ # Checking all input shapes are consistent.
195
+ possible_num_samples = []
196
+ if num_samples is not None:
197
+ possible_num_samples.append(num_samples)
198
+ elif prompt is not None:
199
+ possible_num_samples.append(prompt.shape[0])
200
+ elif conditions:
201
+ possible_num_samples.append(len(conditions))
202
+ else:
203
+ possible_num_samples.append(1)
204
+ assert [x == possible_num_samples[0] for x in possible_num_samples], "Inconsistent inputs shapes"
205
+ num_samples = possible_num_samples[0]
206
+
207
+ # below we create set of conditions: one conditional and one unconditional
208
+ # to do that we merge the regular condition together with the null condition
209
+ # we then do 1 forward pass instead of 2.
210
+ cfg_conditions: tp.Optional[ConditionTensors]
211
+ if conditions:
212
+ null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
213
+ conditions = conditions + null_conditions
214
+ tokenized = self.condition_provider.tokenize(conditions)
215
+ cfg_conditions = self.condition_provider(tokenized)
216
+ else:
217
+ cfg_conditions = {}
218
+
219
+ if prompt is None:
220
+ assert num_samples > 0
221
+ prompt = torch.zeros((num_samples, self.num_codebooks, 0), dtype=torch.long, device=device)
222
+
223
+ B, K, prompt_length = prompt.shape
224
+ start_offset = prompt_length
225
+ assert start_offset < max_gen_len
226
+
227
+ mask_id = self.special_token_id
228
+
229
+ # we generate codes with a fixed sequence length
230
+ shape = (B, K, max_gen_len)
231
+
232
+ gen_codes = torch.full(shape, mask_id, dtype=torch.long, device=device)
233
+ # filling the gen_codes with the prompt if needed
234
+ gen_codes[..., :start_offset] = prompt
235
+ # create the gen_sequence with proper interleaving from the pattern: [B, K, S]
236
+ gen_sequence = gen_codes
237
+
238
+ curr_step = 0
239
+ for stage, n_steps in zip(range(self.n_q), decoding_steps):
240
+ gen_sequence, curr_step = self._generate_stage(gen_sequence,
241
+ cfg_conditions,
242
+ stage=stage,
243
+ device=device,
244
+ prompt_length=prompt_length,
245
+ prompt=prompt,
246
+ temp=temp,
247
+ max_cfg_coef=max_cfg_coef,
248
+ min_cfg_coef=min_cfg_coef,
249
+ top_k=top_k,
250
+ top_p=top_p,
251
+ timesteps=n_steps,
252
+ anneal_temp=anneal_temp,
253
+ span_scoring=span_scoring,
254
+ use_sampling=use_sampling,
255
+ span_arrangement=span_arrangement,
256
+ curr_step=curr_step,
257
+ total_steps=sum(decoding_steps),
258
+ callback=callback)
259
+
260
+ return gen_sequence
261
+
262
+ @torch.no_grad()
263
+ def _generate_stage(self,
264
+ gen_sequence: torch.Tensor,
265
+ condition_tensors: tp.Optional[ConditionTensors],
266
+ stage: int,
267
+ device: torch.device,
268
+ prompt_length: int = 0,
269
+ prompt: tp.Optional[torch.Tensor] = None,
270
+ use_sampling: bool = True,
271
+ temp: float = 3.0,
272
+ max_cfg_coef: float = 10.0,
273
+ min_cfg_coef: float = 1.0,
274
+ top_k: int = 0,
275
+ top_p: float = 0.0,
276
+ timesteps: int = 10,
277
+ anneal_temp: bool = True,
278
+ span_scoring: str = 'max',
279
+ span_arrangement: str = 'nonoverlap',
280
+ curr_step: int = 0,
281
+ total_steps: int = 0,
282
+ callback: tp.Optional[tp.Callable[[int, int], None]] = None) -> tp.Tuple[torch.Tensor, int]:
283
+ """Generate audio tokens of a single RVQ level (stage), given the previously generated stages,
284
+ and the textual conditions.
285
+ Args:
286
+ gen_sequence (torch.Tensor): Previously generated tokens.
287
+ condition_tensors (tp.Optional[ConditionTensors]): pre-computed conditioning tensors.
288
+ stage (int): RVQ level to generate.
289
+ device (torch.device): device of the output tensor.
290
+ prompt_length (int): Temporal length of the audio prompt.
291
+ prompt (torch.Tensor): Prompt tokens of shape [B, K, T].
292
+ use_sampling (bool): Whether to use a sampling strategy or not.
293
+ temp (float): Initial sampling temperature.
294
+ max_clsfg_coef (float): Initial coefficient used for classifier free guidance.
295
+ min_clsfg_coef (float): Final coefficient used for classifier free guidance.
296
+ top_k (int): k for "top-k" sampling.
297
+ top_p (float): p for "top-p" sampling.
298
+ timesteps (int): Number of iterative decoding steps.
299
+ anneal_temp (bool): When set to True, softmax temperature will be linearly decayed to zero, at each stage.
300
+ span_scoring (str): Use the maximum probability of each span ('max')
301
+ or the product of probabilities ('prod').
302
+ span_arrangement (str): Use either non-overlapping spans ('nonoverlap') or overlapping spans ('stride1').
303
+ in the masking scheme.
304
+ curr_step (int): Global iterative decoding step counter.
305
+ total_steps (int): Total decoding steps.
306
+ callback (Callback): Callback function to report generation progress.
307
+ Returns:
308
+ tuple(torch.Tensor, int): Generated tokens and the current decoding step counter.
309
+ """
310
+ B, K, T = gen_sequence.shape
311
+ shape = (B, 1, T) # generating a single codebook per stage
312
+
313
+ mask_id = self.special_token_id
314
+ stage_gen_seq = torch.full(shape, mask_id, dtype=torch.long, device=device)
315
+
316
+ assert span_arrangement == 'nonoverlap' or span_arrangement == 'stride1'
317
+ chunk_masking = self.span_len > 1 and span_arrangement == 'nonoverlap'
318
+
319
+ DONT_REMASK_ME_SCORE = -1e4
320
+
321
+ model = self if self._fsdp is None else self._fsdp
322
+
323
+ if chunk_masking:
324
+ # span-wise scores
325
+ n_chunks = T // self.span_len
326
+ if T % self.span_len != 0:
327
+ # trim sequence ending to achieve a multiple of span_len
328
+ T = self.span_len * n_chunks
329
+ gen_sequence = gen_sequence[..., :T]
330
+ stage_gen_seq = stage_gen_seq[..., :T]
331
+
332
+ chunked_shape = (B, 1, n_chunks)
333
+ n_prompt_chunks = prompt_length // self.span_len
334
+ scores = torch.zeros(chunked_shape, dtype=torch.float32, device=device)
335
+ scores[..., :n_prompt_chunks] = DONT_REMASK_ME_SCORE
336
+ num_chunks_to_gen = n_chunks - n_prompt_chunks
337
+ else:
338
+ # token-wise scores
339
+ scores = torch.zeros(shape, dtype=torch.float32, device=device)
340
+ scores[..., :prompt_length] = DONT_REMASK_ME_SCORE
341
+ gen_T = T - prompt_length
342
+
343
+ # run MAGNeT iterative decoding for "timesteps" iterations
344
+ for timestep, steps_left in zip(torch.linspace(0, 1, timesteps, device=device), reversed(range(timesteps))):
345
+
346
+ mask_p = torch.cos(timestep * math.pi * 0.5)
347
+
348
+ if chunk_masking:
349
+ num_masked = max(int((mask_p * num_chunks_to_gen).item()), 1)
350
+ else:
351
+ num_masked = max(int((mask_p * gen_T).item()), 1)
352
+
353
+ # masking
354
+ run_lps_masking = (span_arrangement == 'stride1') and self.span_len > 1
355
+ if run_lps_masking:
356
+ # masking of the k least probable overlapping (stride 1) spans
357
+ mask = torch.concat((
358
+ [self._least_probable_span_masking(scores[[i], :, :], num_masked).to(device)
359
+ for i in range(B)]), dim=0)
360
+ stage_gen_seq[mask] = mask_id
361
+ else:
362
+ # masking of the k least probable non-overlapping spans
363
+ masked = scores.topk(num_masked, dim=-1).indices
364
+ if chunk_masking:
365
+ chunks_mask = torch.full(chunked_shape, False, dtype=torch.bool, device=device)
366
+ chunks_mask = chunks_mask.scatter(2, masked, True)
367
+ mask = torch.repeat_interleave(chunks_mask, self.span_len, dim=-1)
368
+ stage_gen_seq[mask] = mask_id
369
+ else:
370
+ stage_gen_seq = stage_gen_seq.scatter(2, masked, mask_id)
371
+
372
+ if prompt is not None:
373
+ stage_gen_seq[..., :prompt_length] = prompt[:, stage, :].unsqueeze(1)
374
+
375
+ gen_sequence[:, [stage], :] = stage_gen_seq
376
+ if condition_tensors:
377
+ # duplicate input for classifier free guidance
378
+ sequence = torch.cat([gen_sequence, gen_sequence], dim=0)
379
+
380
+ all_logits = model(sequence, [], condition_tensors, stage=stage)
381
+
382
+ if condition_tensors:
383
+ # classifier free guidance with annealing
384
+ cond_logits, uncond_logits = all_logits.split(B, dim=0) # [B, K, T, card]
385
+ clsfg_coef = float(mask_p) * max_cfg_coef + (1 - float(mask_p)) * min_cfg_coef
386
+ logits = uncond_logits + (cond_logits - uncond_logits) * clsfg_coef
387
+ else:
388
+ logits = all_logits
389
+
390
+ # temperature annealing - linear
391
+ t = temp * (steps_left / timesteps) if anneal_temp else temp
392
+
393
+ # sampling
394
+ logits = logits[:, stage, :, :].unsqueeze(1)
395
+ probs = torch.softmax(logits / max(t, 1e-2), dim=-1)
396
+ if use_sampling:
397
+ if top_p > 0.0:
398
+ sampled_tokens = utils.sample_top_p(probs, p=top_p)
399
+ elif top_k > 0:
400
+ sampled_tokens = utils.sample_top_k(probs, k=top_k)
401
+ else:
402
+ sampled_tokens = utils.multinomial(probs, num_samples=1)
403
+ else:
404
+ sampled_tokens = torch.argmax(logits, dim=-1, keepdim=True)
405
+
406
+ # place mask_id token in each of the masked positions
407
+ mask = stage_gen_seq == mask_id
408
+ stage_gen_seq = torch.where(mask, sampled_tokens[..., 0], stage_gen_seq)
409
+ gen_sequence[:, [stage], :] = stage_gen_seq
410
+
411
+ # get probs of sampled tokens
412
+ sampled_probs = torch.gather(probs, 3, sampled_tokens)[..., 0]
413
+
414
+ # span scoring
415
+ if chunk_masking:
416
+ if span_scoring == 'max':
417
+ # max in linear space
418
+ scores = 1 - torch.max(sampled_probs.reshape((B, 1, n_chunks, -1)), dim=-1)[0]
419
+ elif span_scoring == 'prod':
420
+ # prod in log space
421
+ scores = torch.sum(-torch.log(sampled_probs).reshape((B, 1, n_chunks, -1)), dim=-1)
422
+ else:
423
+ raise NotImplementedError
424
+ else:
425
+ # prod in log space for lps masking (stride1)
426
+ scores = -torch.log(sampled_probs)
427
+
428
+ # Fix unmasked tokens by placing inf probs (-inf scores)
429
+ if chunk_masking:
430
+ scores = scores.masked_fill(~chunks_mask, DONT_REMASK_ME_SCORE)
431
+ else:
432
+ scores = scores.masked_fill(~mask, DONT_REMASK_ME_SCORE)
433
+
434
+ if callback is not None:
435
+ curr_step += 1
436
+ callback(curr_step, total_steps)
437
+
438
+ return gen_sequence, curr_step
439
+
440
+ def _construct_spans_mask(self, span_starts: torch.Tensor, T: int, device: torch.device) -> torch.Tensor:
441
+ """Build a [1x1xT] boolean mask consists of overlapping spans of True values, where
442
+ span_starts defines the initial index of each span, and the span length is
443
+ defined by self.span_len.
444
+ Args:
445
+ span_starts (torch.Tensor): Boolean mask determines the temporal location of each span start.
446
+ T (int): Sequence length.
447
+ device (torch.device): device of the output tensor.
448
+ Returns:
449
+ torch.Tensor: Spans mask of shape [1x1xT]
450
+ """
451
+ mask = torch.full((1, 1, T), False, device=device)
452
+ mask[:, :, span_starts] = True
453
+ shifted_mask = mask.clone()
454
+ for _ in range(self.span_len - 1):
455
+ shifted_mask = torch.concat((torch.full((1, 1, 1), False, device=device), shifted_mask[:, :, :-1]), dim=-1)
456
+ mask = torch.logical_or(mask, shifted_mask)
457
+ return mask
458
+
459
+ def _least_probable_span_masking(self, scores: torch.Tensor, num_masked_trg: int) -> torch.Tensor:
460
+ """Construct a [1x1xT] boolean mask, consists of the u least probable spans,
461
+ where the token probability is determined by -scores, and the total
462
+ number of masked tokens is as closest as possible to num_masked_trg.
463
+ Find u using binary search.
464
+ Args:
465
+ scores (torch.Tensor): Per token score [-log(prob)]
466
+ num_masked_trg: int: The desired amount of tokens to be masked.
467
+ Returns:
468
+ torch.Tensor: Spans mask of shape [1x1xT]
469
+ """
470
+ T = scores.shape[-1]
471
+ device = scores.device
472
+ scores_unfolded = scores.unfold(2, self.span_len, 1)
473
+ # Span score is the product of probs (sum in log space)
474
+ span_scores = scores_unfolded.sum(dim=-1)
475
+ spans_by_scores = torch.argsort(span_scores[0, 0], descending=True)
476
+
477
+ num_masked_trg = max(num_masked_trg, self.span_len)
478
+
479
+ # Binary search for u - the number least probable overlapping masked spans s.t.
480
+ # the total masking rate is the closest to num_masked_trg / T.
481
+ min_u = num_masked_trg // self.span_len
482
+ max_u = num_masked_trg - self.span_len + 1
483
+ mid = round(0.5 * (min_u + max_u))
484
+
485
+ if mid == min_u or mid == max_u:
486
+ return self._construct_spans_mask(spans_by_scores[:mid], T, device)
487
+
488
+ while mid > min_u and mid < max_u:
489
+ mask = self._construct_spans_mask(spans_by_scores[:mid], T, device)
490
+ n_masked = mask.sum()
491
+ if n_masked > num_masked_trg:
492
+ max_u = mid
493
+ mid = round(0.5 * (min_u + max_u))
494
+ else:
495
+ min_u = mid
496
+ mid = round(0.5 * (min_u + max_u))
497
+
498
+ return mask
audiocraft/models/loaders.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Utility functions to load from the checkpoints.
9
+ Each checkpoint is a torch.saved dict with the following keys:
10
+ - 'xp.cfg': the hydra config as dumped during training. This should be used
11
+ to rebuild the object using the audiocraft.models.builders functions,
12
+ - 'model_best_state': a readily loadable best state for the model, including
13
+ the conditioner. The model obtained from `xp.cfg` should be compatible
14
+ with this state dict. In the case of a LM, the encodec model would not be
15
+ bundled along but instead provided separately.
16
+
17
+ Those functions also support loading from a remote location with the Torch Hub API.
18
+ They also support overriding some parameters, in particular the device and dtype
19
+ of the returned model.
20
+ """
21
+
22
+ from pathlib import Path
23
+ from huggingface_hub import hf_hub_download
24
+ import typing as tp
25
+ import os
26
+
27
+ from omegaconf import OmegaConf, DictConfig
28
+ import torch
29
+
30
+ import audiocraft
31
+ from . import builders
32
+ from .encodec import CompressionModel
33
+
34
+
35
+ def get_audiocraft_cache_dir() -> tp.Optional[str]:
36
+ return os.environ.get('AUDIOCRAFT_CACHE_DIR', None)
37
+
38
+
39
+ def _get_state_dict(
40
+ file_or_url_or_id: tp.Union[Path, str],
41
+ filename: tp.Optional[str] = None,
42
+ device='cpu',
43
+ cache_dir: tp.Optional[str] = None,
44
+ ):
45
+ if cache_dir is None:
46
+ cache_dir = get_audiocraft_cache_dir()
47
+ # Return the state dict either from a file or url
48
+ file_or_url_or_id = str(file_or_url_or_id)
49
+ assert isinstance(file_or_url_or_id, str)
50
+
51
+ if os.path.isfile(file_or_url_or_id):
52
+ return torch.load(file_or_url_or_id, map_location=device)
53
+
54
+ if os.path.isdir(file_or_url_or_id):
55
+ file = f"{file_or_url_or_id}/{filename}"
56
+ return torch.load(file, map_location=device)
57
+
58
+ elif file_or_url_or_id.startswith('https://'):
59
+ return torch.hub.load_state_dict_from_url(file_or_url_or_id, map_location=device, check_hash=True)
60
+
61
+ else:
62
+ assert filename is not None, "filename needs to be defined if using HF checkpoints"
63
+
64
+ file = hf_hub_download(
65
+ repo_id=file_or_url_or_id, filename=filename, cache_dir=cache_dir,
66
+ library_name="audiocraft", library_version=audiocraft.__version__)
67
+ return torch.load(file, map_location=device)
68
+
69
+
70
+ def load_compression_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None):
71
+ return _get_state_dict(file_or_url_or_id, filename="compression_state_dict.bin", cache_dir=cache_dir)
72
+
73
+
74
+ def load_compression_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
75
+ pkg = load_compression_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
76
+ if 'pretrained' in pkg:
77
+ return CompressionModel.get_pretrained(pkg['pretrained'], device=device)
78
+ cfg = OmegaConf.create(pkg['xp.cfg'])
79
+ cfg.device = str(device)
80
+ model = builders.get_compression_model(cfg)
81
+ model.load_state_dict(pkg['best_state'])
82
+ model.eval()
83
+ return model
84
+
85
+
86
+ def load_lm_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None):
87
+ return _get_state_dict(file_or_url_or_id, filename="state_dict.bin", cache_dir=cache_dir)
88
+
89
+
90
+ def _delete_param(cfg: DictConfig, full_name: str):
91
+ parts = full_name.split('.')
92
+ for part in parts[:-1]:
93
+ if part in cfg:
94
+ cfg = cfg[part]
95
+ else:
96
+ return
97
+ OmegaConf.set_struct(cfg, False)
98
+ if parts[-1] in cfg:
99
+ del cfg[parts[-1]]
100
+ OmegaConf.set_struct(cfg, True)
101
+
102
+
103
+ def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
104
+ pkg = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
105
+ cfg = OmegaConf.create(pkg['xp.cfg'])
106
+ cfg.device = str(device)
107
+ if cfg.device == 'cpu':
108
+ cfg.dtype = 'float32'
109
+ else:
110
+ cfg.dtype = 'float16'
111
+ _delete_param(cfg, 'conditioners.self_wav.chroma_stem.cache_path')
112
+ _delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
113
+ _delete_param(cfg, 'conditioners.args.drop_desc_p')
114
+ model = builders.get_lm_model(cfg)
115
+ model.load_state_dict(pkg['best_state'])
116
+ model.eval()
117
+ model.cfg = cfg
118
+ return model
119
+
120
+
121
+ def load_lm_model_magnet(file_or_url_or_id: tp.Union[Path, str], compression_model_frame_rate: int,
122
+ device='cpu', cache_dir: tp.Optional[str] = None):
123
+ pkg = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
124
+ cfg = OmegaConf.create(pkg['xp.cfg'])
125
+ cfg.device = str(device)
126
+ if cfg.device == 'cpu':
127
+ cfg.dtype = 'float32'
128
+ else:
129
+ cfg.dtype = 'float16'
130
+ _delete_param(cfg, 'conditioners.args.merge_text_conditions_p')
131
+ _delete_param(cfg, 'conditioners.args.drop_desc_p')
132
+
133
+ cfg.transformer_lm.compression_model_framerate = compression_model_frame_rate
134
+ cfg.transformer_lm.segment_duration = cfg.dataset.segment_duration
135
+ cfg.transformer_lm.span_len = cfg.masking.span_len
136
+
137
+ # MAGNeT models v1 support only xformers backend.
138
+ from audiocraft.modules.transformer import set_efficient_attention_backend
139
+ if cfg.transformer_lm.memory_efficient:
140
+ set_efficient_attention_backend("xformers")
141
+
142
+ model = builders.get_lm_model(cfg)
143
+ model.load_state_dict(pkg['best_state'])
144
+ model.eval()
145
+ model.cfg = cfg
146
+ return model
147
+
148
+
149
+ def load_mbd_ckpt(file_or_url_or_id: tp.Union[Path, str],
150
+ filename: tp.Optional[str] = None,
151
+ cache_dir: tp.Optional[str] = None):
152
+ return _get_state_dict(file_or_url_or_id, filename=filename, cache_dir=cache_dir)
153
+
154
+
155
+ def load_diffusion_models(file_or_url_or_id: tp.Union[Path, str],
156
+ device='cpu',
157
+ filename: tp.Optional[str] = None,
158
+ cache_dir: tp.Optional[str] = None):
159
+ pkg = load_mbd_ckpt(file_or_url_or_id, filename=filename, cache_dir=cache_dir)
160
+ models = []
161
+ processors = []
162
+ cfgs = []
163
+ sample_rate = pkg['sample_rate']
164
+ for i in range(pkg['n_bands']):
165
+ cfg = pkg[i]['cfg']
166
+ model = builders.get_diffusion_model(cfg)
167
+ model_dict = pkg[i]['model_state']
168
+ model.load_state_dict(model_dict)
169
+ model.to(device)
170
+ processor = builders.get_processor(cfg=cfg.processor, sample_rate=sample_rate)
171
+ processor_dict = pkg[i]['processor_state']
172
+ processor.load_state_dict(processor_dict)
173
+ processor.to(device)
174
+ models.append(model)
175
+ processors.append(processor)
176
+ cfgs.append(cfg)
177
+ return models, processors, cfgs
audiocraft/models/magnet.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Main model for using MAGNeT. This will combine all the required components
9
+ and provide easy access to the generation API.
10
+ """
11
+ import typing as tp
12
+ import torch
13
+
14
+ from .genmodel import BaseGenModel
15
+ from .loaders import load_compression_model, load_lm_model_magnet
16
+
17
+
18
+ class MAGNeT(BaseGenModel):
19
+ """MAGNeT main model with convenient generation API.
20
+ Args:
21
+ See MusicGen class.
22
+ """
23
+ def __init__(self, **kwargs):
24
+ super().__init__(**kwargs)
25
+ # MAGNeT operates over a fixed sequence length defined in it's config.
26
+ self.duration = self.lm.cfg.dataset.segment_duration
27
+ self.set_generation_params()
28
+
29
+ @staticmethod
30
+ def get_pretrained(name: str = 'facebook/magnet-small-10secs', device=None):
31
+ """Return pretrained model, we provide six models:
32
+ - facebook/magnet-small-10secs (300M), text to music, 10-second audio samples.
33
+ # see: https://huggingface.co/facebook/magnet-small-10secs
34
+ - facebook/magnet-medium-10secs (1.5B), text to music, 10-second audio samples.
35
+ # see: https://huggingface.co/facebook/magnet-medium-10secs
36
+ - facebook/magnet-small-30secs (300M), text to music, 30-second audio samples.
37
+ # see: https://huggingface.co/facebook/magnet-small-30secs
38
+ - facebook/magnet-medium-30secs (1.5B), text to music, 30-second audio samples.
39
+ # see: https://huggingface.co/facebook/magnet-medium-30secs
40
+ - facebook/audio-magnet-small (300M), text to sound-effect (10-second samples).
41
+ # see: https://huggingface.co/facebook/audio-magnet-small
42
+ - facebook/audio-magnet-medium (1.5B), text to sound-effect (10-second samples).
43
+ # see: https://huggingface.co/facebook/audio-magnet-medium
44
+ """
45
+ if device is None:
46
+ if torch.cuda.device_count():
47
+ device = 'cuda'
48
+ else:
49
+ device = 'cpu'
50
+
51
+ compression_model = load_compression_model(name, device=device)
52
+ lm = load_lm_model_magnet(name, compression_model_frame_rate=int(compression_model.frame_rate), device=device)
53
+
54
+ if 'self_wav' in lm.condition_provider.conditioners:
55
+ lm.condition_provider.conditioners['self_wav'].match_len_on_eval = True
56
+
57
+ kwargs = {'name': name, 'compression_model': compression_model, 'lm': lm}
58
+ return MAGNeT(**kwargs)
59
+
60
+ def set_generation_params(self, use_sampling: bool = True, top_k: int = 0,
61
+ top_p: float = 0.9, temperature: float = 3.0,
62
+ max_cfg_coef: float = 10.0, min_cfg_coef: float = 1.0,
63
+ decoding_steps: tp.List[int] = [20, 10, 10, 10],
64
+ span_arrangement: str = 'nonoverlap'):
65
+ """Set the generation parameters for MAGNeT.
66
+
67
+ Args:
68
+ use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True.
69
+ top_k (int, optional): top_k used for sampling. Defaults to 0.
70
+ top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.9.
71
+ temperature (float, optional): Initial softmax temperature parameter. Defaults to 3.0.
72
+ max_cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 10.0.
73
+ min_cfg_coef (float, optional): End coefficient of classifier free guidance annealing. Defaults to 1.0.
74
+ decoding_steps (list of n_q ints, optional): The number of iterative decoding steps,
75
+ for each of the n_q RVQ codebooks.
76
+ span_arrangement (str, optional): Use either non-overlapping spans ('nonoverlap')
77
+ or overlapping spans ('stride1') in the masking scheme.
78
+ """
79
+ self.generation_params = {
80
+ 'use_sampling': use_sampling,
81
+ 'temp': temperature,
82
+ 'top_k': top_k,
83
+ 'top_p': top_p,
84
+ 'max_cfg_coef': max_cfg_coef,
85
+ 'min_cfg_coef': min_cfg_coef,
86
+ 'decoding_steps': [int(s) for s in decoding_steps],
87
+ 'span_arrangement': span_arrangement
88
+ }
audiocraft/models/multibanddiffusion.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Multi Band Diffusion models as described in
9
+ "From Discrete Tokens to High-Fidelity Audio Using Multi-Band Diffusion"
10
+ (paper link).
11
+ """
12
+
13
+ import typing as tp
14
+
15
+ import torch
16
+ import julius
17
+
18
+ from .unet import DiffusionUnet
19
+ from ..modules.diffusion_schedule import NoiseSchedule
20
+ from .encodec import CompressionModel
21
+ from ..solvers.compression import CompressionSolver
22
+ from .loaders import load_compression_model, load_diffusion_models
23
+
24
+
25
+ class DiffusionProcess:
26
+ """Sampling for a diffusion Model.
27
+
28
+ Args:
29
+ model (DiffusionUnet): Diffusion U-Net model.
30
+ noise_schedule (NoiseSchedule): Noise schedule for diffusion process.
31
+ """
32
+ def __init__(self, model: DiffusionUnet, noise_schedule: NoiseSchedule) -> None:
33
+ self.model = model
34
+ self.schedule = noise_schedule
35
+
36
+ def generate(self, condition: torch.Tensor, initial_noise: torch.Tensor,
37
+ step_list: tp.Optional[tp.List[int]] = None):
38
+ """Perform one diffusion process to generate one of the bands.
39
+
40
+ Args:
41
+ condition (torch.Tensor): The embeddings from the compression model.
42
+ initial_noise (torch.Tensor): The initial noise to start the process.
43
+ """
44
+ return self.schedule.generate_subsampled(model=self.model, initial=initial_noise, step_list=step_list,
45
+ condition=condition)
46
+
47
+
48
+ class MultiBandDiffusion:
49
+ """Sample from multiple diffusion models.
50
+
51
+ Args:
52
+ DPs (list of DiffusionProcess): Diffusion processes.
53
+ codec_model (CompressionModel): Underlying compression model used to obtain discrete tokens.
54
+ """
55
+ def __init__(self, DPs: tp.List[DiffusionProcess], codec_model: CompressionModel) -> None:
56
+ self.DPs = DPs
57
+ self.codec_model = codec_model
58
+ self.device = next(self.codec_model.parameters()).device
59
+
60
+ @property
61
+ def sample_rate(self) -> int:
62
+ return self.codec_model.sample_rate
63
+
64
+ @staticmethod
65
+ def get_mbd_musicgen(device=None):
66
+ """Load our diffusion models trained for MusicGen."""
67
+ if device is None:
68
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
69
+ path = 'facebook/multiband-diffusion'
70
+ filename = 'mbd_musicgen_32khz.th'
71
+ name = 'facebook/musicgen-small'
72
+ codec_model = load_compression_model(name, device=device)
73
+ models, processors, cfgs = load_diffusion_models(path, filename=filename, device=device)
74
+ DPs = []
75
+ for i in range(len(models)):
76
+ schedule = NoiseSchedule(**cfgs[i].schedule, sample_processor=processors[i], device=device)
77
+ DPs.append(DiffusionProcess(model=models[i], noise_schedule=schedule))
78
+ return MultiBandDiffusion(DPs=DPs, codec_model=codec_model)
79
+
80
+ @staticmethod
81
+ def get_mbd_24khz(bw: float = 3.0,
82
+ device: tp.Optional[tp.Union[torch.device, str]] = None,
83
+ n_q: tp.Optional[int] = None):
84
+ """Get the pretrained Models for MultibandDiffusion.
85
+
86
+ Args:
87
+ bw (float): Bandwidth of the compression model.
88
+ device (torch.device or str, optional): Device on which the models are loaded.
89
+ n_q (int, optional): Number of quantizers to use within the compression model.
90
+ """
91
+ if device is None:
92
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
93
+ assert bw in [1.5, 3.0, 6.0], f"bandwidth {bw} not available"
94
+ if n_q is not None:
95
+ assert n_q in [2, 4, 8]
96
+ assert {1.5: 2, 3.0: 4, 6.0: 8}[bw] == n_q, \
97
+ f"bandwidth and number of codebooks missmatch to use n_q = {n_q} bw should be {n_q * (1.5 / 2)}"
98
+ n_q = {1.5: 2, 3.0: 4, 6.0: 8}[bw]
99
+ codec_model = CompressionSolver.model_from_checkpoint(
100
+ '//pretrained/facebook/encodec_24khz', device=device)
101
+ codec_model.set_num_codebooks(n_q)
102
+ codec_model = codec_model.to(device)
103
+ path = 'facebook/multiband-diffusion'
104
+ filename = f'mbd_comp_{n_q}.pt'
105
+ models, processors, cfgs = load_diffusion_models(path, filename=filename, device=device)
106
+ DPs = []
107
+ for i in range(len(models)):
108
+ schedule = NoiseSchedule(**cfgs[i].schedule, sample_processor=processors[i], device=device)
109
+ DPs.append(DiffusionProcess(model=models[i], noise_schedule=schedule))
110
+ return MultiBandDiffusion(DPs=DPs, codec_model=codec_model)
111
+
112
+ @torch.no_grad()
113
+ def get_condition(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
114
+ """Get the conditioning (i.e. latent representations of the compression model) from a waveform.
115
+ Args:
116
+ wav (torch.Tensor): The audio that we want to extract the conditioning from.
117
+ sample_rate (int): Sample rate of the audio."""
118
+ if sample_rate != self.sample_rate:
119
+ wav = julius.resample_frac(wav, sample_rate, self.sample_rate)
120
+ codes, scale = self.codec_model.encode(wav)
121
+ assert scale is None, "Scaled compression models not supported."
122
+ emb = self.get_emb(codes)
123
+ return emb
124
+
125
+ @torch.no_grad()
126
+ def get_emb(self, codes: torch.Tensor):
127
+ """Get latent representation from the discrete codes.
128
+ Args:
129
+ codes (torch.Tensor): Discrete tokens."""
130
+ emb = self.codec_model.decode_latent(codes)
131
+ return emb
132
+
133
+ def generate(self, emb: torch.Tensor, size: tp.Optional[torch.Size] = None,
134
+ step_list: tp.Optional[tp.List[int]] = None):
135
+ """Generate waveform audio from the latent embeddings of the compression model.
136
+ Args:
137
+ emb (torch.Tensor): Conditioning embeddings
138
+ size (None, torch.Size): Size of the output
139
+ if None this is computed from the typical upsampling of the model.
140
+ step_list (list[int], optional): list of Markov chain steps, defaults to 50 linearly spaced step.
141
+ """
142
+ if size is None:
143
+ upsampling = int(self.codec_model.sample_rate / self.codec_model.frame_rate)
144
+ size = torch.Size([emb.size(0), self.codec_model.channels, emb.size(-1) * upsampling])
145
+ assert size[0] == emb.size(0)
146
+ out = torch.zeros(size).to(self.device)
147
+ for DP in self.DPs:
148
+ out += DP.generate(condition=emb, step_list=step_list, initial_noise=torch.randn_like(out))
149
+ return out
150
+
151
+ def re_eq(self, wav: torch.Tensor, ref: torch.Tensor, n_bands: int = 32, strictness: float = 1):
152
+ """Match the eq to the encodec output by matching the standard deviation of some frequency bands.
153
+ Args:
154
+ wav (torch.Tensor): Audio to equalize.
155
+ ref (torch.Tensor): Reference audio from which we match the spectrogram.
156
+ n_bands (int): Number of bands of the eq.
157
+ strictness (float): How strict the matching. 0 is no matching, 1 is exact matching.
158
+ """
159
+ split = julius.SplitBands(n_bands=n_bands, sample_rate=self.codec_model.sample_rate).to(wav.device)
160
+ bands = split(wav)
161
+ bands_ref = split(ref)
162
+ out = torch.zeros_like(ref)
163
+ for i in range(n_bands):
164
+ out += bands[i] * (bands_ref[i].std() / bands[i].std()) ** strictness
165
+ return out
166
+
167
+ def regenerate(self, wav: torch.Tensor, sample_rate: int):
168
+ """Regenerate a waveform through compression and diffusion regeneration.
169
+ Args:
170
+ wav (torch.Tensor): Original 'ground truth' audio.
171
+ sample_rate (int): Sample rate of the input (and output) wav.
172
+ """
173
+ if sample_rate != self.codec_model.sample_rate:
174
+ wav = julius.resample_frac(wav, sample_rate, self.codec_model.sample_rate)
175
+ emb = self.get_condition(wav, sample_rate=self.codec_model.sample_rate)
176
+ size = wav.size()
177
+ out = self.generate(emb, size=size)
178
+ if sample_rate != self.codec_model.sample_rate:
179
+ out = julius.resample_frac(out, self.codec_model.sample_rate, sample_rate)
180
+ return out
181
+
182
+ def tokens_to_wav(self, tokens: torch.Tensor, n_bands: int = 32):
183
+ """Generate Waveform audio with diffusion from the discrete codes.
184
+ Args:
185
+ tokens (torch.Tensor): Discrete codes.
186
+ n_bands (int): Bands for the eq matching.
187
+ """
188
+ wav_encodec = self.codec_model.decode(tokens)
189
+ condition = self.get_emb(tokens)
190
+ wav_diffusion = self.generate(emb=condition, size=wav_encodec.size())
191
+ return self.re_eq(wav=wav_diffusion, ref=wav_encodec, n_bands=n_bands)
audiocraft/models/musicgen.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Main model for using MusicGen. This will combine all the required components
9
+ and provide easy access to the generation API.
10
+ """
11
+
12
+ import typing as tp
13
+ import warnings
14
+
15
+ import torch
16
+
17
+ from .encodec import CompressionModel
18
+ from .genmodel import BaseGenModel
19
+ from .lm import LMModel
20
+ from .builders import get_debug_compression_model, get_debug_lm_model
21
+ from .loaders import load_compression_model, load_lm_model
22
+ from ..data.audio_utils import convert_audio
23
+ from ..modules.conditioners import ConditioningAttributes, WavCondition
24
+
25
+
26
+ MelodyList = tp.List[tp.Optional[torch.Tensor]]
27
+ MelodyType = tp.Union[torch.Tensor, MelodyList]
28
+
29
+
30
+ # backward compatible names mapping
31
+ _HF_MODEL_CHECKPOINTS_MAP = {
32
+ "small": "facebook/musicgen-small",
33
+ "medium": "facebook/musicgen-medium",
34
+ "large": "facebook/musicgen-large",
35
+ "melody": "facebook/musicgen-melody",
36
+ }
37
+
38
+
39
+ class MusicGen(BaseGenModel):
40
+ """MusicGen main model with convenient generation API.
41
+
42
+ Args:
43
+ name (str): name of the model.
44
+ compression_model (CompressionModel): Compression model
45
+ used to map audio to invertible discrete representations.
46
+ lm (LMModel): Language model over discrete representations.
47
+ max_duration (float, optional): maximum duration the model can produce,
48
+ otherwise, inferred from the training params.
49
+ """
50
+ def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
51
+ max_duration: tp.Optional[float] = None):
52
+ super().__init__(name, compression_model, lm, max_duration)
53
+ self.set_generation_params(duration=15) # default duration
54
+
55
+ @staticmethod
56
+ def get_pretrained(name: str = 'facebook/musicgen-melody', device=None):
57
+ """Return pretrained model, we provide four models:
58
+ - facebook/musicgen-small (300M), text to music,
59
+ # see: https://huggingface.co/facebook/musicgen-small
60
+ - facebook/musicgen-medium (1.5B), text to music,
61
+ # see: https://huggingface.co/facebook/musicgen-medium
62
+ - facebook/musicgen-melody (1.5B) text to music and text+melody to music,
63
+ # see: https://huggingface.co/facebook/musicgen-melody
64
+ - facebook/musicgen-large (3.3B), text to music,
65
+ # see: https://huggingface.co/facebook/musicgen-large
66
+ """
67
+ if device is None:
68
+ if torch.cuda.device_count():
69
+ device = 'cuda'
70
+ else:
71
+ device = 'cpu'
72
+
73
+ if name == 'debug':
74
+ # used only for unit tests
75
+ compression_model = get_debug_compression_model(device)
76
+ lm = get_debug_lm_model(device)
77
+ return MusicGen(name, compression_model, lm, max_duration=30)
78
+
79
+ if name in _HF_MODEL_CHECKPOINTS_MAP:
80
+ warnings.warn(
81
+ "MusicGen pretrained model relying on deprecated checkpoint mapping. " +
82
+ f"Please use full pre-trained id instead: facebook/musicgen-{name}")
83
+ name = _HF_MODEL_CHECKPOINTS_MAP[name]
84
+
85
+ lm = load_lm_model(name, device=device)
86
+ compression_model = load_compression_model(name, device=device)
87
+ if 'self_wav' in lm.condition_provider.conditioners:
88
+ lm.condition_provider.conditioners['self_wav'].match_len_on_eval = True
89
+ lm.condition_provider.conditioners['self_wav']._use_masking = False
90
+
91
+ return MusicGen(name, compression_model, lm)
92
+
93
+ def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
94
+ top_p: float = 0.0, temperature: float = 1.0,
95
+ duration: float = 30.0, cfg_coef: float = 3.0,
96
+ two_step_cfg: bool = False, extend_stride: float = 18):
97
+ """Set the generation parameters for MusicGen.
98
+
99
+ Args:
100
+ use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True.
101
+ top_k (int, optional): top_k used for sampling. Defaults to 250.
102
+ top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.
103
+ temperature (float, optional): Softmax temperature parameter. Defaults to 1.0.
104
+ duration (float, optional): Duration of the generated waveform. Defaults to 30.0.
105
+ cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0.
106
+ two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
107
+ instead of batching together the two. This has some impact on how things
108
+ are padded but seems to have little impact in practice.
109
+ extend_stride: when doing extended generation (i.e. more than 30 seconds), by how much
110
+ should we extend the audio each time. Larger values will mean less context is
111
+ preserved, and shorter value will require extra computations.
112
+ """
113
+ assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
114
+ self.extend_stride = extend_stride
115
+ self.duration = duration
116
+ self.generation_params = {
117
+ 'use_sampling': use_sampling,
118
+ 'temp': temperature,
119
+ 'top_k': top_k,
120
+ 'top_p': top_p,
121
+ 'cfg_coef': cfg_coef,
122
+ 'two_step_cfg': two_step_cfg,
123
+ }
124
+
125
+ def generate_with_chroma(self, descriptions: tp.List[str], melody_wavs: MelodyType,
126
+ melody_sample_rate: int, progress: bool = False,
127
+ return_tokens: bool = False) -> tp.Union[torch.Tensor,
128
+ tp.Tuple[torch.Tensor, torch.Tensor]]:
129
+ """Generate samples conditioned on text and melody.
130
+
131
+ Args:
132
+ descriptions (list of str): A list of strings used as text conditioning.
133
+ melody_wavs: (torch.Tensor or list of Tensor): A batch of waveforms used as
134
+ melody conditioning. Should have shape [B, C, T] with B matching the description length,
135
+ C=1 or 2. It can be [C, T] if there is a single description. It can also be
136
+ a list of [C, T] tensors.
137
+ melody_sample_rate: (int): Sample rate of the melody waveforms.
138
+ progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
139
+ """
140
+ if isinstance(melody_wavs, torch.Tensor):
141
+ if melody_wavs.dim() == 2:
142
+ melody_wavs = melody_wavs[None]
143
+ if melody_wavs.dim() != 3:
144
+ raise ValueError("Melody wavs should have a shape [B, C, T].")
145
+ melody_wavs = list(melody_wavs)
146
+ else:
147
+ for melody in melody_wavs:
148
+ if melody is not None:
149
+ assert melody.dim() == 2, "One melody in the list has the wrong number of dims."
150
+
151
+ melody_wavs = [
152
+ convert_audio(wav, melody_sample_rate, self.sample_rate, self.audio_channels)
153
+ if wav is not None else None
154
+ for wav in melody_wavs]
155
+ attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions=descriptions, prompt=None,
156
+ melody_wavs=melody_wavs)
157
+ assert prompt_tokens is None
158
+ tokens = self._generate_tokens(attributes, prompt_tokens, progress)
159
+ if return_tokens:
160
+ return self.generate_audio(tokens), tokens
161
+ return self.generate_audio(tokens)
162
+
163
+ @torch.no_grad()
164
+ def _prepare_tokens_and_attributes(
165
+ self,
166
+ descriptions: tp.Sequence[tp.Optional[str]],
167
+ prompt: tp.Optional[torch.Tensor],
168
+ melody_wavs: tp.Optional[MelodyList] = None,
169
+ ) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]:
170
+ """Prepare model inputs.
171
+
172
+ Args:
173
+ descriptions (list of str): A list of strings used as text conditioning.
174
+ prompt (torch.Tensor): A batch of waveforms used for continuation.
175
+ melody_wavs (torch.Tensor, optional): A batch of waveforms
176
+ used as melody conditioning. Defaults to None.
177
+ """
178
+ attributes = [
179
+ ConditioningAttributes(text={'description': description})
180
+ for description in descriptions]
181
+
182
+ if melody_wavs is None:
183
+ for attr in attributes:
184
+ attr.wav['self_wav'] = WavCondition(
185
+ torch.zeros((1, 1, 1), device=self.device),
186
+ torch.tensor([0], device=self.device),
187
+ sample_rate=[self.sample_rate],
188
+ path=[None])
189
+ else:
190
+ if 'self_wav' not in self.lm.condition_provider.conditioners:
191
+ raise RuntimeError("This model doesn't support melody conditioning. "
192
+ "Use the `melody` model.")
193
+ assert len(melody_wavs) == len(descriptions), \
194
+ f"number of melody wavs must match number of descriptions! " \
195
+ f"got melody len={len(melody_wavs)}, and descriptions len={len(descriptions)}"
196
+ for attr, melody in zip(attributes, melody_wavs):
197
+ if melody is None:
198
+ attr.wav['self_wav'] = WavCondition(
199
+ torch.zeros((1, 1, 1), device=self.device),
200
+ torch.tensor([0], device=self.device),
201
+ sample_rate=[self.sample_rate],
202
+ path=[None])
203
+ else:
204
+ attr.wav['self_wav'] = WavCondition(
205
+ melody[None].to(device=self.device),
206
+ torch.tensor([melody.shape[-1]], device=self.device),
207
+ sample_rate=[self.sample_rate],
208
+ path=[None],
209
+ )
210
+
211
+ if prompt is not None:
212
+ if descriptions is not None:
213
+ assert len(descriptions) == len(prompt), "Prompt and nb. descriptions doesn't match"
214
+ prompt = prompt.to(self.device)
215
+ prompt_tokens, scale = self.compression_model.encode(prompt)
216
+ assert scale is None
217
+ else:
218
+ prompt_tokens = None
219
+ return attributes, prompt_tokens
220
+
221
+ def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
222
+ prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor:
223
+ """Generate discrete audio tokens given audio prompt and/or conditions.
224
+
225
+ Args:
226
+ attributes (list of ConditioningAttributes): Conditions used for generation (text/melody).
227
+ prompt_tokens (torch.Tensor, optional): Audio prompt used for continuation.
228
+ progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
229
+ Returns:
230
+ torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
231
+ """
232
+ total_gen_len = int(self.duration * self.frame_rate)
233
+ max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate)
234
+ current_gen_offset: int = 0
235
+
236
+ def _progress_callback(generated_tokens: int, tokens_to_generate: int):
237
+ generated_tokens += current_gen_offset
238
+ if self._progress_callback is not None:
239
+ # Note that total_gen_len might be quite wrong depending on the
240
+ # codebook pattern used, but with delay it is almost accurate.
241
+ self._progress_callback(generated_tokens, tokens_to_generate)
242
+ else:
243
+ print(f'{generated_tokens: 6d} / {tokens_to_generate: 6d}', end='\r')
244
+
245
+ if prompt_tokens is not None:
246
+ assert max_prompt_len >= prompt_tokens.shape[-1], \
247
+ "Prompt is longer than audio to generate"
248
+
249
+ callback = None
250
+ if progress:
251
+ callback = _progress_callback
252
+
253
+ if self.duration <= self.max_duration:
254
+ # generate by sampling from LM, simple case.
255
+ with self.autocast:
256
+ gen_tokens = self.lm.generate(
257
+ prompt_tokens, attributes,
258
+ callback=callback, max_gen_len=total_gen_len, **self.generation_params)
259
+
260
+ else:
261
+ # now this gets a bit messier, we need to handle prompts,
262
+ # melody conditioning etc.
263
+ ref_wavs = [attr.wav['self_wav'] for attr in attributes]
264
+ all_tokens = []
265
+ if prompt_tokens is None:
266
+ prompt_length = 0
267
+ else:
268
+ all_tokens.append(prompt_tokens)
269
+ prompt_length = prompt_tokens.shape[-1]
270
+
271
+ assert self.extend_stride is not None, "Stride should be defined to generate beyond max_duration"
272
+ assert self.extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
273
+ stride_tokens = int(self.frame_rate * self.extend_stride)
274
+
275
+ while current_gen_offset + prompt_length < total_gen_len:
276
+ time_offset = current_gen_offset / self.frame_rate
277
+ chunk_duration = min(self.duration - time_offset, self.max_duration)
278
+ max_gen_len = int(chunk_duration * self.frame_rate)
279
+ for attr, ref_wav in zip(attributes, ref_wavs):
280
+ wav_length = ref_wav.length.item()
281
+ if wav_length == 0:
282
+ continue
283
+ # We will extend the wav periodically if it not long enough.
284
+ # we have to do it here rather than in conditioners.py as otherwise
285
+ # we wouldn't have the full wav.
286
+ initial_position = int(time_offset * self.sample_rate)
287
+ wav_target_length = int(self.max_duration * self.sample_rate)
288
+ positions = torch.arange(initial_position,
289
+ initial_position + wav_target_length, device=self.device)
290
+ attr.wav['self_wav'] = WavCondition(
291
+ ref_wav[0][..., positions % wav_length],
292
+ torch.full_like(ref_wav[1], wav_target_length),
293
+ [self.sample_rate] * ref_wav[0].size(0),
294
+ [None], [0.])
295
+ with self.autocast:
296
+ gen_tokens = self.lm.generate(
297
+ prompt_tokens, attributes,
298
+ callback=callback, max_gen_len=max_gen_len, **self.generation_params)
299
+ if prompt_tokens is None:
300
+ all_tokens.append(gen_tokens)
301
+ else:
302
+ all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:])
303
+ prompt_tokens = gen_tokens[:, :, stride_tokens:]
304
+ prompt_length = prompt_tokens.shape[-1]
305
+ current_gen_offset += stride_tokens
306
+
307
+ gen_tokens = torch.cat(all_tokens, dim=-1)
308
+ return gen_tokens
audiocraft/models/unet.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Pytorch Unet Module used for diffusion.
9
+ """
10
+
11
+ from dataclasses import dataclass
12
+ import typing as tp
13
+
14
+ import torch
15
+ from torch import nn
16
+ from torch.nn import functional as F
17
+ from audiocraft.modules.transformer import StreamingTransformer, create_sin_embedding
18
+
19
+
20
+ @dataclass
21
+ class Output:
22
+ sample: torch.Tensor
23
+
24
+
25
+ def get_model(cfg, channels: int, side: int, num_steps: int):
26
+ if cfg.model == 'unet':
27
+ return DiffusionUnet(
28
+ chin=channels, num_steps=num_steps, **cfg.diffusion_unet)
29
+ else:
30
+ raise RuntimeError('Not Implemented')
31
+
32
+
33
+ class ResBlock(nn.Module):
34
+ def __init__(self, channels: int, kernel: int = 3, norm_groups: int = 4,
35
+ dilation: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
36
+ dropout: float = 0.):
37
+ super().__init__()
38
+ stride = 1
39
+ padding = dilation * (kernel - stride) // 2
40
+ Conv = nn.Conv1d
41
+ Drop = nn.Dropout1d
42
+ self.norm1 = nn.GroupNorm(norm_groups, channels)
43
+ self.conv1 = Conv(channels, channels, kernel, 1, padding, dilation=dilation)
44
+ self.activation1 = activation()
45
+ self.dropout1 = Drop(dropout)
46
+
47
+ self.norm2 = nn.GroupNorm(norm_groups, channels)
48
+ self.conv2 = Conv(channels, channels, kernel, 1, padding, dilation=dilation)
49
+ self.activation2 = activation()
50
+ self.dropout2 = Drop(dropout)
51
+
52
+ def forward(self, x):
53
+ h = self.dropout1(self.conv1(self.activation1(self.norm1(x))))
54
+ h = self.dropout2(self.conv2(self.activation2(self.norm2(h))))
55
+ return x + h
56
+
57
+
58
+ class DecoderLayer(nn.Module):
59
+ def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2,
60
+ norm_groups: int = 4, res_blocks: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
61
+ dropout: float = 0.):
62
+ super().__init__()
63
+ padding = (kernel - stride) // 2
64
+ self.res_blocks = nn.Sequential(
65
+ *[ResBlock(chin, norm_groups=norm_groups, dilation=2**idx, dropout=dropout)
66
+ for idx in range(res_blocks)])
67
+ self.norm = nn.GroupNorm(norm_groups, chin)
68
+ ConvTr = nn.ConvTranspose1d
69
+ self.convtr = ConvTr(chin, chout, kernel, stride, padding, bias=False)
70
+ self.activation = activation()
71
+
72
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
73
+ x = self.res_blocks(x)
74
+ x = self.norm(x)
75
+ x = self.activation(x)
76
+ x = self.convtr(x)
77
+ return x
78
+
79
+
80
+ class EncoderLayer(nn.Module):
81
+ def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2,
82
+ norm_groups: int = 4, res_blocks: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
83
+ dropout: float = 0.):
84
+ super().__init__()
85
+ padding = (kernel - stride) // 2
86
+ Conv = nn.Conv1d
87
+ self.conv = Conv(chin, chout, kernel, stride, padding, bias=False)
88
+ self.norm = nn.GroupNorm(norm_groups, chout)
89
+ self.activation = activation()
90
+ self.res_blocks = nn.Sequential(
91
+ *[ResBlock(chout, norm_groups=norm_groups, dilation=2**idx, dropout=dropout)
92
+ for idx in range(res_blocks)])
93
+
94
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
95
+ B, C, T = x.shape
96
+ stride, = self.conv.stride
97
+ pad = (stride - (T % stride)) % stride
98
+ x = F.pad(x, (0, pad))
99
+
100
+ x = self.conv(x)
101
+ x = self.norm(x)
102
+ x = self.activation(x)
103
+ x = self.res_blocks(x)
104
+ return x
105
+
106
+
107
+ class BLSTM(nn.Module):
108
+ """BiLSTM with same hidden units as input dim.
109
+ """
110
+ def __init__(self, dim, layers=2):
111
+ super().__init__()
112
+ self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
113
+ self.linear = nn.Linear(2 * dim, dim)
114
+
115
+ def forward(self, x):
116
+ x = x.permute(2, 0, 1)
117
+ x = self.lstm(x)[0]
118
+ x = self.linear(x)
119
+ x = x.permute(1, 2, 0)
120
+ return x
121
+
122
+
123
+ class DiffusionUnet(nn.Module):
124
+ def __init__(self, chin: int = 3, hidden: int = 24, depth: int = 3, growth: float = 2.,
125
+ max_channels: int = 10_000, num_steps: int = 1000, emb_all_layers=False, cross_attention: bool = False,
126
+ bilstm: bool = False, transformer: bool = False,
127
+ codec_dim: tp.Optional[int] = None, **kwargs):
128
+ super().__init__()
129
+ self.encoders = nn.ModuleList()
130
+ self.decoders = nn.ModuleList()
131
+ self.embeddings: tp.Optional[nn.ModuleList] = None
132
+ self.embedding = nn.Embedding(num_steps, hidden)
133
+ if emb_all_layers:
134
+ self.embeddings = nn.ModuleList()
135
+ self.condition_embedding: tp.Optional[nn.Module] = None
136
+ for d in range(depth):
137
+ encoder = EncoderLayer(chin, hidden, **kwargs)
138
+ decoder = DecoderLayer(hidden, chin, **kwargs)
139
+ self.encoders.append(encoder)
140
+ self.decoders.insert(0, decoder)
141
+ if emb_all_layers and d > 0:
142
+ assert self.embeddings is not None
143
+ self.embeddings.append(nn.Embedding(num_steps, hidden))
144
+ chin = hidden
145
+ hidden = min(int(chin * growth), max_channels)
146
+ self.bilstm: tp.Optional[nn.Module]
147
+ if bilstm:
148
+ self.bilstm = BLSTM(chin)
149
+ else:
150
+ self.bilstm = None
151
+ self.use_transformer = transformer
152
+ self.cross_attention = False
153
+ if transformer:
154
+ self.cross_attention = cross_attention
155
+ self.transformer = StreamingTransformer(chin, 8, 6, bias_ff=False, bias_attn=False,
156
+ cross_attention=cross_attention)
157
+
158
+ self.use_codec = False
159
+ if codec_dim is not None:
160
+ self.conv_codec = nn.Conv1d(codec_dim, chin, 1)
161
+ self.use_codec = True
162
+
163
+ def forward(self, x: torch.Tensor, step: tp.Union[int, torch.Tensor], condition: tp.Optional[torch.Tensor] = None):
164
+ skips = []
165
+ bs = x.size(0)
166
+ z = x
167
+ view_args = [1]
168
+ if type(step) is torch.Tensor:
169
+ step_tensor = step
170
+ else:
171
+ step_tensor = torch.tensor([step], device=x.device, dtype=torch.long).expand(bs)
172
+
173
+ for idx, encoder in enumerate(self.encoders):
174
+ z = encoder(z)
175
+ if idx == 0:
176
+ z = z + self.embedding(step_tensor).view(bs, -1, *view_args).expand_as(z)
177
+ elif self.embeddings is not None:
178
+ z = z + self.embeddings[idx - 1](step_tensor).view(bs, -1, *view_args).expand_as(z)
179
+
180
+ skips.append(z)
181
+
182
+ if self.use_codec: # insert condition in the bottleneck
183
+ assert condition is not None, "Model defined for conditionnal generation"
184
+ condition_emb = self.conv_codec(condition) # reshape to the bottleneck dim
185
+ assert condition_emb.size(-1) <= 2 * z.size(-1), \
186
+ f"You are downsampling the conditionning with factor >=2 : {condition_emb.size(-1)=} and {z.size(-1)=}"
187
+ if not self.cross_attention:
188
+
189
+ condition_emb = torch.nn.functional.interpolate(condition_emb, z.size(-1))
190
+ assert z.size() == condition_emb.size()
191
+ z += condition_emb
192
+ cross_attention_src = None
193
+ else:
194
+ cross_attention_src = condition_emb.permute(0, 2, 1) # B, T, C
195
+ B, T, C = cross_attention_src.shape
196
+ positions = torch.arange(T, device=x.device).view(1, -1, 1)
197
+ pos_emb = create_sin_embedding(positions, C, max_period=10_000, dtype=cross_attention_src.dtype)
198
+ cross_attention_src = cross_attention_src + pos_emb
199
+ if self.use_transformer:
200
+ z = self.transformer(z.permute(0, 2, 1), cross_attention_src=cross_attention_src).permute(0, 2, 1)
201
+ else:
202
+ if self.bilstm is None:
203
+ z = torch.zeros_like(z)
204
+ else:
205
+ z = self.bilstm(z)
206
+
207
+ for decoder in self.decoders:
208
+ s = skips.pop(-1)
209
+ z = z[:, :, :s.shape[2]]
210
+ z = z + s
211
+ z = decoder(z)
212
+
213
+ z = z[:, :, :x.shape[2]]
214
+ return Output(z)