Spaces:
Sleeping
Sleeping
unpairedelectron07
committed on
Upload 11 files
Browse files- audiocraft/models/audiogen.py +93 -0
- audiocraft/models/builders.py +261 -0
- audiocraft/models/encodec.py +506 -0
- audiocraft/models/genmodel.py +267 -0
- audiocraft/models/lm.py +547 -0
- audiocraft/models/lm_magnet.py +498 -0
- audiocraft/models/loaders.py +177 -0
- audiocraft/models/magnet.py +88 -0
- audiocraft/models/multibanddiffusion.py +191 -0
- audiocraft/models/musicgen.py +308 -0
- audiocraft/models/unet.py +214 -0
audiocraft/models/audiogen.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
Main model for using AudioGen. This will combine all the required components
|
9 |
+
and provide easy access to the generation API.
|
10 |
+
"""
|
11 |
+
|
12 |
+
import typing as tp
|
13 |
+
|
14 |
+
import torch
|
15 |
+
|
16 |
+
from .encodec import CompressionModel
|
17 |
+
from .genmodel import BaseGenModel
|
18 |
+
from .lm import LMModel
|
19 |
+
from .builders import get_debug_compression_model, get_debug_lm_model
|
20 |
+
from .loaders import load_compression_model, load_lm_model
|
21 |
+
|
22 |
+
|
23 |
+
class AudioGen(BaseGenModel):
    """AudioGen model bundled with a convenient generation API.

    Args:
        name (str): name of the model.
        compression_model (CompressionModel): Compression model
            used to map audio to invertible discrete representations.
        lm (LMModel): Language model over discrete representations.
        max_duration (float, optional): maximum duration the model can produce,
            otherwise, inferred from the training params.
    """
    def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
                 max_duration: tp.Optional[float] = None):
        super().__init__(name, compression_model, lm, max_duration)
        # Install the default generation parameters, with a 5 second duration.
        self.set_generation_params(duration=5)

    @staticmethod
    def get_pretrained(name: str = 'facebook/audiogen-medium', device=None):
        """Return a pretrained AudioGen model. A single model is provided for now:

        - facebook/audiogen-medium (1.5B), text to sound,
          see: https://huggingface.co/facebook/audiogen-medium

        The special name 'debug' builds small debug models (used only for unit tests).
        When `device` is None, 'cuda' is picked if `torch.cuda.device_count()` is
        non-zero, otherwise 'cpu'.
        """
        if device is None:
            device = 'cuda' if torch.cuda.device_count() else 'cpu'

        if name == 'debug':
            # used only for unit tests
            debug_codec = get_debug_compression_model(device, sample_rate=16000)
            debug_lm = get_debug_lm_model(device)
            return AudioGen(name, debug_codec, debug_lm, max_duration=10)

        codec = load_compression_model(name, device=device)
        language_model = load_lm_model(name, device=device)
        conditioners = language_model.condition_provider.conditioners
        assert 'self_wav' not in conditioners, \
            "AudioGen do not support waveform conditioning for now"
        return AudioGen(name, codec, language_model)

    def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
                              top_p: float = 0.0, temperature: float = 1.0,
                              duration: float = 10.0, cfg_coef: float = 3.0,
                              two_step_cfg: bool = False, extend_stride: float = 2):
        """Configure how AudioGen generates audio.

        Args:
            use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True.
            top_k (int, optional): top_k used for sampling. Defaults to 250.
            top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.
            temperature (float, optional): Softmax temperature parameter. Defaults to 1.0.
            duration (float, optional): Duration of the generated waveform. Defaults to 10.0.
            cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0.
            two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
                instead of batching together the two. This has some impact on how things
                are padded but seems to have little impact in practice.
            extend_stride: when doing extended generation (i.e. more than 10 seconds), by how much
                should we extend the audio each time. Larger values will mean less context is
                preserved, and shorter value will require extra computations.
                Must be strictly smaller than `self.max_duration`.
        """
        assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
        self.extend_stride = extend_stride
        self.duration = duration
        # Note the temperature is stored under the 'temp' key.
        self.generation_params = dict(
            use_sampling=use_sampling,
            temp=temperature,
            top_k=top_k,
            top_p=top_p,
            cfg_coef=cfg_coef,
            two_step_cfg=two_step_cfg,
        )
|
audiocraft/models/builders.py
ADDED
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
All the functions to build the relevant models and modules
|
9 |
+
from the Hydra config.
|
10 |
+
"""
|
11 |
+
|
12 |
+
import typing as tp
|
13 |
+
|
14 |
+
import audiocraft
|
15 |
+
import omegaconf
|
16 |
+
import torch
|
17 |
+
|
18 |
+
from .encodec import CompressionModel, EncodecModel, InterleaveStereoCompressionModel
|
19 |
+
from .lm import LMModel
|
20 |
+
from .lm_magnet import MagnetLMModel
|
21 |
+
from ..modules.codebooks_patterns import (
|
22 |
+
CodebooksPatternProvider,
|
23 |
+
DelayedPatternProvider,
|
24 |
+
MusicLMPattern,
|
25 |
+
ParallelPatternProvider,
|
26 |
+
UnrolledPatternProvider,
|
27 |
+
CoarseFirstPattern,
|
28 |
+
)
|
29 |
+
from ..modules.conditioners import (
|
30 |
+
BaseConditioner,
|
31 |
+
ChromaStemConditioner,
|
32 |
+
CLAPEmbeddingConditioner,
|
33 |
+
ConditionFuser,
|
34 |
+
ConditioningProvider,
|
35 |
+
LUTConditioner,
|
36 |
+
T5Conditioner,
|
37 |
+
)
|
38 |
+
from .unet import DiffusionUnet
|
39 |
+
from .. import quantization as qt
|
40 |
+
from ..utils.utils import dict_from_config
|
41 |
+
from ..modules.diffusion_schedule import MultiBandProcessor, SampleProcessor
|
42 |
+
|
43 |
+
|
44 |
+
def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: int) -> qt.BaseQuantizer:
    """Instantiate the quantizer selected by name.

    Args:
        quantizer (str): Quantizer key, one of 'no_quant' or 'rvq'.
        cfg (omegaconf.DictConfig): Config holding a section named after `quantizer`,
            whose content is forwarded as keyword arguments.
        dimension (int): Passed as the `dimension` keyword argument, except for 'no_quant'.

    Returns:
        qt.BaseQuantizer: The constructed quantizer.

    Raises:
        KeyError: If `quantizer` is not one of the supported keys.
    """
    quantizer_classes = {
        'no_quant': qt.DummyQuantizer,
        'rvq': qt.ResidualVectorQuantizer,
    }
    quantizer_cls = quantizer_classes[quantizer]
    quantizer_kwargs = dict_from_config(getattr(cfg, quantizer))
    if quantizer == 'no_quant':
        # The dummy quantizer is built without the `dimension` argument.
        return quantizer_cls(**quantizer_kwargs)
    quantizer_kwargs['dimension'] = dimension
    return quantizer_cls(**quantizer_kwargs)
|
53 |
+
|
54 |
+
|
55 |
+
def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig):
    """Instantiate the encoder/decoder pair of an EnCodec autoencoder.

    Args:
        encoder_name (str): Name of the autoencoder architecture. Only 'seanet' is supported.
        cfg (omegaconf.DictConfig): Config holding a `seanet` section. Its `encoder` and
            `decoder` sub-sections override the shared keyword arguments for each side.

    Returns:
        tuple: `(encoder, decoder)` as a SEANetEncoder and a SEANetDecoder.

    Raises:
        KeyError: If `encoder_name` is not a supported autoencoder.
    """
    if encoder_name != 'seanet':
        # Report the value that was actually rejected: the autoencoder name,
        # not the (valid) compression model name stored in the config.
        raise KeyError(f"Unexpected autoencoder {encoder_name}")

    kwargs = dict_from_config(getattr(cfg, 'seanet'))
    encoder_override_kwargs = kwargs.pop('encoder')
    decoder_override_kwargs = kwargs.pop('decoder')
    # Side-specific values take precedence over the shared ones.
    encoder_kwargs = {**kwargs, **encoder_override_kwargs}
    decoder_kwargs = {**kwargs, **decoder_override_kwargs}
    encoder = audiocraft.modules.SEANetEncoder(**encoder_kwargs)
    decoder = audiocraft.modules.SEANetDecoder(**decoder_kwargs)
    return encoder, decoder
|
67 |
+
|
68 |
+
|
69 |
+
def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel:
    """Instantiate a compression model.

    Only `cfg.compression_model == 'encodec'` is supported. The `encodec` config section
    supplies the autoencoder and quantizer names plus the remaining `EncodecModel`
    keyword arguments. The frame rate is computed as `sample_rate // encoder.hop_length`.

    Raises:
        KeyError: If `cfg.compression_model` is not 'encodec'.
    """
    if cfg.compression_model != 'encodec':
        raise KeyError(f"Unexpected compression model {cfg.compression_model}")

    encodec_kwargs = dict_from_config(getattr(cfg, 'encodec'))
    autoencoder_name = encodec_kwargs.pop('autoencoder')
    quantizer_name = encodec_kwargs.pop('quantizer')
    encoder, decoder = get_encodec_autoencoder(autoencoder_name, cfg)
    quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension)
    # `sample_rate` is read but kept in the kwargs: the model needs it as well.
    frame_rate = encodec_kwargs['sample_rate'] // encoder.hop_length
    renormalize = encodec_kwargs.pop('renormalize', False)
    # deprecated params
    encodec_kwargs.pop('renorm', None)
    model = EncodecModel(
        encoder, decoder, quantizer,
        frame_rate=frame_rate, renormalize=renormalize, **encodec_kwargs)
    return model.to(cfg.device)
|
85 |
+
|
86 |
+
|
87 |
+
def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel:
    """Instantiate a transformer LM.

    `cfg.lm_model` selects the class: 'transformer_lm' builds an `LMModel` and
    'transformer_lm_magnet' builds a `MagnetLMModel`. Both read their keyword
    arguments from the `transformer_lm` config section.

    Raises:
        KeyError: If `cfg.lm_model` is not one of the supported names.
    """
    if cfg.lm_model not in ['transformer_lm', 'transformer_lm_magnet']:
        raise KeyError(f"Unexpected LM model {cfg.lm_model}")

    lm_kwargs = dict_from_config(cfg.transformer_lm)
    n_q = lm_kwargs['n_q']
    q_modeling = lm_kwargs.pop('q_modeling', None)
    pattern_cfg = cfg.codebooks_pattern
    attribute_dropout = dict_from_config(cfg.attribute_dropout)
    guidance_cfg = dict_from_config(cfg.classifier_free_guidance)
    cfg_prob = guidance_cfg['training_dropout']
    cfg_coef = guidance_cfg['inference_coef']
    fuser = get_condition_fuser(cfg)
    condition_provider = get_conditioner_provider(lm_kwargs["dim"], cfg).to(cfg.device)

    # enforce cross-att programmatically
    if len(fuser.fuse2cond['cross']) > 0:
        lm_kwargs['cross_attention'] = True

    if pattern_cfg.modeling is None:
        # No explicit pattern: fall back on `transformer_lm.q_modeling`,
        # with delays 0..n_q-1 for the 'delay' pattern.
        assert q_modeling is not None, \
            "LM model should either have a codebook pattern defined or transformer_lm.q_modeling"
        pattern_cfg = omegaconf.OmegaConf.create(
            {'modeling': q_modeling, 'delay': {'delays': list(range(n_q))}}
        )

    pattern_provider = get_codebooks_pattern_provider(n_q, pattern_cfg)
    if cfg.lm_model == 'transformer_lm_magnet':
        lm_class = MagnetLMModel
    else:
        lm_class = LMModel
    model = lm_class(
        pattern_provider=pattern_provider,
        condition_provider=condition_provider,
        fuser=fuser,
        cfg_dropout=cfg_prob,
        cfg_coef=cfg_coef,
        attribute_dropout=attribute_dropout,
        dtype=getattr(torch, cfg.dtype),
        device=cfg.device,
        **lm_kwargs
    )
    return model.to(cfg.device)
|
123 |
+
|
124 |
+
|
125 |
+
def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> ConditioningProvider:
    """Instantiate a conditioning model.

    Args:
        output_dim (int): Passed as `output_dim` to every conditioner.
        cfg (omegaconf.DictConfig): Config providing `device`, `dataset.segment_duration`
            and a `conditioners` section (may be None). Each entry of that section names
            its model type under 'model' and its arguments under a key named after the type.
            The optional 'args' entry holds keyword arguments for the provider itself.

    Returns:
        ConditioningProvider: Provider wrapping all the configured conditioners.

    Raises:
        ValueError: If a conditioner declares an unrecognized model type.
    """
    device = cfg.device
    duration = cfg.dataset.segment_duration
    conditioners_cfg = cfg.conditioners
    if conditioners_cfg is None:
        entries = {}
    else:
        entries = dict_from_config(conditioners_cfg)

    provider_kwargs = entries.pop('args', {})
    # These two options are dropped before the provider is built.
    for ignored_key in ('merge_text_conditions_p', 'drop_desc_p'):
        provider_kwargs.pop(ignored_key, None)

    # Conditioner class and the extra keyword arguments each model type receives.
    builders_by_type = {
        't5': (T5Conditioner, {'device': device}),
        'lut': (LUTConditioner, {}),
        'chroma_stem': (ChromaStemConditioner, {'duration': duration, 'device': device}),
        'clap': (CLAPEmbeddingConditioner, {'device': device}),
    }

    conditioners: tp.Dict[str, BaseConditioner] = {}
    for cond_name, cond_cfg in entries.items():
        model_type = cond_cfg['model']
        model_args = cond_cfg[model_type]
        if model_type not in builders_by_type:
            raise ValueError(f"Unrecognized conditioning model: {model_type}")
        conditioner_cls, extra_kwargs = builders_by_type[model_type]
        conditioners[str(cond_name)] = conditioner_cls(
            output_dim=output_dim, **extra_kwargs, **model_args)

    return ConditioningProvider(conditioners, device=device, **provider_kwargs)
|
160 |
+
|
161 |
+
|
162 |
+
def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:
    """Instantiate a condition fuser object.

    The `fuser` config section must contain one entry per fusing method
    ('sum', 'cross', 'prepend', 'input_interpolate'). These form the `fuse2cond`
    mapping; every other entry is forwarded as a keyword argument.
    """
    fuser_cfg = cfg.fuser
    method_names = ('sum', 'cross', 'prepend', 'input_interpolate')

    fuse2cond = {}
    for method in method_names:
        fuse2cond[method] = fuser_cfg[method]

    extra_kwargs = {}
    for key, value in fuser_cfg.items():
        if key in method_names:
            continue
        extra_kwargs[key] = value

    return ConditionFuser(fuse2cond=fuse2cond, **extra_kwargs)
|
170 |
+
|
171 |
+
|
172 |
+
def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider:
    """Instantiate a codebooks pattern provider object.

    Args:
        n_q (int): Passed as first positional argument to the provider.
        cfg (omegaconf.DictConfig): Config whose `modeling` entry names the pattern.
            When it also has a section named after the pattern, that section is
            forwarded as keyword arguments.

    Raises:
        KeyError: If `cfg.modeling` is not a known pattern name.
    """
    providers_by_name = {
        'parallel': ParallelPatternProvider,
        'delay': DelayedPatternProvider,
        'unroll': UnrolledPatternProvider,
        'coarse_first': CoarseFirstPattern,
        'musiclm': MusicLMPattern,
    }
    pattern_name = cfg.modeling
    if hasattr(cfg, pattern_name):
        provider_kwargs = dict_from_config(cfg.get(pattern_name))
    else:
        provider_kwargs = {}
    provider_cls = providers_by_name[pattern_name]
    return provider_cls(n_q, **provider_kwargs)
|
185 |
+
|
186 |
+
|
187 |
+
def get_debug_compression_model(device='cpu', sample_rate: int = 32000):
    """Instantiate a debug compression model to be used for unit tests.

    Args:
        device: Device the model is moved to.
        sample_rate (int): Audio sample rate, must be 16000 or 32000.

    Returns:
        A small single-channel `EncodecModel` in eval mode.
    """
    assert sample_rate in (16000, 32000), "unsupported sample rate for debug compression model"
    ratios_by_sample_rate = {
        16000: [10, 8, 8],  # 25 Hz at 16kHz
        32000: [10, 8, 16],  # 25 Hz at 32kHz
    }
    ratios: tp.List[int] = ratios_by_sample_rate[sample_rate]
    frame_rate = 25
    seanet_kwargs: dict = dict(
        n_filters=4,
        n_residual_layers=1,
        dimension=32,
        ratios=ratios,
    )
    encoder = audiocraft.modules.SEANetEncoder(**seanet_kwargs)
    decoder = audiocraft.modules.SEANetDecoder(**seanet_kwargs)
    quantizer = qt.ResidualVectorQuantizer(dimension=32, bins=400, n_q=4)
    # Run the quantizer once on random data to initialize kmeans etc.
    warmup_input = torch.randn(8, 32, 128)
    quantizer(warmup_input, 1)
    model = EncodecModel(
        encoder, decoder, quantizer,
        frame_rate=frame_rate, sample_rate=sample_rate, channels=1)
    model = model.to(device)
    return model.eval()
|
211 |
+
|
212 |
+
|
213 |
+
def get_diffusion_model(cfg: omegaconf.DictConfig):
    """Instantiate a `DiffusionUnet` from the config.

    Reads the input channels from `cfg.channels`, the number of steps from
    `cfg.schedule.num_steps`, and forwards `cfg.diffusion_unet` as keyword arguments.
    """
    # TODO Find a way to infer the channels from dset
    return DiffusionUnet(
        chin=cfg.channels,
        num_steps=cfg.schedule.num_steps,
        **cfg.diffusion_unet)
|
219 |
+
|
220 |
+
|
221 |
+
def get_processor(cfg, sample_rate: int = 24000):
    """Instantiate the sample processor described by the config.

    Args:
        cfg: Mapping-like config with at least `use` and `name` entries. The other
            entries are forwarded as keyword arguments to `MultiBandProcessor`.
        sample_rate (int): Sample rate given to `MultiBandProcessor`.

    Returns:
        A `MultiBandProcessor` when `cfg.use` is truthy and `cfg.name` is
        "multi_band_processor", otherwise a default `SampleProcessor`.
    """
    default_processor = SampleProcessor()
    if not cfg.use:
        return default_processor

    options = dict(cfg)
    # `use` and `name` only drive the selection, they are not processor arguments.
    for selector_key in ('use', 'name'):
        options.pop(selector_key)
    if cfg.name != "multi_band_processor":
        return default_processor
    return MultiBandProcessor(sample_rate=sample_rate, **options)
|
230 |
+
|
231 |
+
|
232 |
+
def get_debug_lm_model(device='cpu'):
    """Instantiate a debug LM to be used for unit tests.

    The model uses a delayed pattern over 4 codebooks and a single 'description'
    LUT conditioner fused through cross-attention. It is returned on `device`
    in eval mode.
    """
    dim = 16
    pattern_provider = DelayedPatternProvider(n_q=4)
    description_conditioner = LUTConditioner(
        n_bins=128, dim=dim, output_dim=dim, tokenizer="whitespace")
    condition_provider = ConditioningProvider({'description': description_conditioner})
    fuse2cond = {
        'cross': ['description'],
        'prepend': [],
        'sum': [],
        'input_interpolate': [],
    }
    fuser = ConditionFuser(fuse2cond)
    debug_lm = LMModel(
        pattern_provider, condition_provider, fuser,
        n_q=4, card=400, dim=dim, num_heads=4, custom=True, num_layers=2,
        cross_attention=True, causal=True)
    debug_lm = debug_lm.to(device)
    return debug_lm.eval()
|
248 |
+
|
249 |
+
|
250 |
+
def get_wrapped_compression_model(
        compression_model: CompressionModel,
        cfg: omegaconf.DictConfig) -> CompressionModel:
    """Apply the optional config-driven adjustments to a compression model.

    - When `cfg.interleave_stereo_codebooks` exists and its `use` flag is truthy, the model
      is wrapped in an `InterleaveStereoCompressionModel` built from the remaining options.
    - When `cfg.compression_model_n_q` exists and is not None, it becomes the active
      number of codebooks of the (possibly wrapped) model.

    Returns:
        CompressionModel: The input model, wrapped if requested.
    """
    wants_interleave = (
        hasattr(cfg, 'interleave_stereo_codebooks')
        and cfg.interleave_stereo_codebooks.use
    )
    if wants_interleave:
        interleave_kwargs = dict_from_config(cfg.interleave_stereo_codebooks)
        interleave_kwargs.pop('use')
        compression_model = InterleaveStereoCompressionModel(compression_model, **interleave_kwargs)

    if hasattr(cfg, 'compression_model_n_q') and cfg.compression_model_n_q is not None:
        compression_model.set_num_codebooks(cfg.compression_model_n_q)
    return compression_model
|
audiocraft/models/encodec.py
ADDED
@@ -0,0 +1,506 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
"""Compression models or wrapper around existing models.
|
7 |
+
Also defines the main interface that a model must follow to be usable as an audio tokenizer.
|
8 |
+
"""
|
9 |
+
|
10 |
+
from abc import ABC, abstractmethod
|
11 |
+
import logging
|
12 |
+
import math
|
13 |
+
from pathlib import Path
|
14 |
+
import typing as tp
|
15 |
+
|
16 |
+
from einops import rearrange
|
17 |
+
import numpy as np
|
18 |
+
import torch
|
19 |
+
from torch import nn
|
20 |
+
from transformers import EncodecModel as HFEncodecModel
|
21 |
+
|
22 |
+
from .. import quantization as qt
|
23 |
+
|
24 |
+
|
25 |
+
logger = logging.getLogger()
|
26 |
+
|
27 |
+
|
28 |
+
class CompressionModel(ABC, nn.Module):
    """Abstract interface that every compression model must implement
    to be usable as an audio tokenizer with a language model.
    """

    @abstractmethod
    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
        """Run the full model on `x` and return the quantization result."""

    @abstractmethod
    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """See `EncodecModel.encode`."""

    @abstractmethod
    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
        """See `EncodecModel.decode`."""

    @abstractmethod
    def decode_latent(self, codes: torch.Tensor):
        """Decode from the discrete codes to continuous latent space."""

    @property
    @abstractmethod
    def channels(self) -> int:
        """Number of audio channels."""

    @property
    @abstractmethod
    def frame_rate(self) -> float:
        """Frame rate of the latent representation."""

    @property
    @abstractmethod
    def sample_rate(self) -> int:
        """Audio sample rate."""

    @property
    @abstractmethod
    def cardinality(self) -> int:
        """Cardinality of each codebook."""

    @property
    @abstractmethod
    def num_codebooks(self) -> int:
        """Active number of codebooks."""

    @property
    @abstractmethod
    def total_codebooks(self) -> int:
        """Total number of codebooks available."""

    @abstractmethod
    def set_num_codebooks(self, n: int):
        """Set the active number of codebooks used by the quantizer."""

    @staticmethod
    def get_pretrained(
            name: str, device: tp.Union[torch.device, str] = 'cpu'
            ) -> 'CompressionModel':
        """Instantiate a CompressionModel from a given pretrained model.

        Args:
            name (Path or str): name of the pretrained model. See after.
            device (torch.device or str): Device on which the model is loaded.

        Pretrained models:
            - dac_44khz (https://github.com/descriptinc/descript-audio-codec)
            - dac_24khz (same)
            - facebook/encodec_24khz (https://huggingface.co/facebook/encodec_24khz)
            - facebook/encodec_32khz (https://huggingface.co/facebook/encodec_32khz)
            - your own model on Hugging Face. Export instructions to come...

        Returns:
            CompressionModel: The model, moved to `device` and set in eval mode.
        """
        # Imported here rather than at module level.
        from . import builders, loaders

        model: CompressionModel
        if name in ('dac_44khz', 'dac_24khz'):
            # The part after the underscore ('44khz' / '24khz') is the DAC model type.
            model_type = name.split('_')[1]
            logger.info("Getting pretrained compression model from DAC %s", model_type)
            model = DAC(model_type)
        elif name == 'debug_compression_model':
            logger.info("Getting pretrained compression model for debug")
            model = builders.get_debug_compression_model()
        elif Path(name).exists():
            # We assume here if the path exists that it is in fact an AC checkpoint
            # that was exported using `audiocraft.utils.export` functions.
            model = loaders.load_compression_model(name, device=device)
        else:
            logger.info("Getting pretrained compression model from HF %s", name)
            hf_model = HFEncodecModel.from_pretrained(name)
            model = HFEncodecCompressionModel(hf_model).to(device)

        model = model.to(device)
        return model.eval()
|
123 |
+
|
124 |
+
|
125 |
+
class EncodecModel(CompressionModel):
    """Encodec model operating on the raw waveform.

    Args:
        encoder (nn.Module): Encoder network.
        decoder (nn.Module): Decoder network.
        quantizer (qt.BaseQuantizer): Quantizer network.
        frame_rate (int): Frame rate for the latent representation.
        sample_rate (int): Audio sample rate.
        channels (int): Number of audio channels.
        causal (bool): Whether to use a causal version of the model.
        renormalize (bool): Whether to renormalize the audio before running the model.
    """
    # Plain class attributes are assigned here to override the abstract
    # properties declared on the base class.
    frame_rate: float = 0
    sample_rate: int = 0
    channels: int = 0

    def __init__(self,
                 encoder: nn.Module,
                 decoder: nn.Module,
                 quantizer: qt.BaseQuantizer,
                 frame_rate: int,
                 sample_rate: int,
                 channels: int,
                 causal: bool = False,
                 renormalize: bool = False):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.quantizer = quantizer
        self.frame_rate = frame_rate
        self.sample_rate = sample_rate
        self.channels = channels
        self.renormalize = renormalize
        self.causal = causal
        if self.causal:
            # we force disabling here to avoid handling linear overlap of segments
            # as supported in original EnCodec codebase.
            assert not self.renormalize, 'Causal model does not support renormalize'

    @property
    def total_codebooks(self):
        """Total number of quantizer codebooks available."""
        return self.quantizer.total_codebooks

    @property
    def num_codebooks(self):
        """Active number of codebooks used by the quantizer."""
        return self.quantizer.num_codebooks

    def set_num_codebooks(self, n: int):
        """Set the active number of codebooks used by the quantizer."""
        self.quantizer.set_num_codebooks(n)

    @property
    def cardinality(self):
        """Cardinality of each codebook."""
        return self.quantizer.bins

    def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """Optionally renormalize the audio before encoding.

        When `renormalize` is disabled, returns `x` untouched with a None scale.
        Otherwise `x` is divided by the RMS of its channel average (plus 1e-8),
        and that scale is returned reshaped to two dimensions.
        """
        if not self.renormalize:
            return x, None
        mono = x.mean(dim=1, keepdim=True)
        volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
        # Small constant avoids dividing by zero on silent input.
        scale: tp.Optional[torch.Tensor] = 1e-8 + volume
        normalized = x / scale
        return normalized, scale.view(-1, 1)

    def postprocess(self,
                    x: torch.Tensor,
                    scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
        """Undo the renormalization by multiplying `x` by `scale`, when a scale is given."""
        if scale is None:
            return x
        assert self.renormalize
        return x * scale.view(-1, 1, 1)

    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
        """Encode, quantize and decode `x`, a 3-dim tensor.

        The returned quantization result has its `x` field replaced by the
        reconstructed audio, trimmed to the input length.
        """
        assert x.dim() == 3
        input_length = x.shape[-1]
        x, scale = self.preprocess(x)

        latent = self.encoder(x)
        q_res = self.quantizer(latent, self.frame_rate)
        reconstructed = self.decoder(q_res.x)

        # remove extra padding added by the encoder and decoder
        assert reconstructed.shape[-1] >= input_length, (reconstructed.shape[-1], input_length)
        reconstructed = reconstructed[..., :input_length]

        q_res.x = self.postprocess(reconstructed, scale)
        return q_res

    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """Encode the given input tensor to quantized representation along with scale parameter.

        Args:
            x (torch.Tensor): Float tensor of shape [B, C, T]

        Returns:
            codes, scale (tuple of torch.Tensor, torch.Tensor): Tuple composed of:
                codes: a tensor of shape [B, K, T] with K the number of codebooks used and T the timestep.
                scale: a float tensor containing the scale for audio renormalization,
                    or None when renormalization is disabled.
        """
        assert x.dim() == 3
        x, scale = self.preprocess(x)
        latent = self.encoder(x)
        return self.quantizer.encode(latent), scale

    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
        """Decode the given codes to a reconstructed representation, using the scale to perform
        audio denormalization if needed.

        Args:
            codes (torch.Tensor): Int tensor of shape [B, K, T]
            scale (torch.Tensor, optional): Float tensor containing the scale value.

        Returns:
            out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio.
        """
        latent = self.decode_latent(codes)
        decoded = self.decoder(latent)
        # The result still contains the extra padding added by the encoder and decoder.
        return self.postprocess(decoded, scale)

    def decode_latent(self, codes: torch.Tensor):
        """Decode from the discrete codes to continuous latent space."""
        return self.quantizer.decode(codes)
|
260 |
+
|
261 |
+
|
262 |
+
class DAC(CompressionModel):
    """Compression model wrapping a model loaded through `dac.utils.load_model`.

    Only encoding and decoding are supported: `forward` raises `NotImplementedError`.

    Args:
        model_type (str): Model type forwarded to `dac.utils.load_model`.

    Raises:
        RuntimeError: If the `dac` package cannot be imported.
    """
    def __init__(self, model_type: str = "44khz"):
        super().__init__()
        try:
            import dac.utils
        except ImportError as err:
            # Chain explicitly so the original import failure is kept as the cause.
            raise RuntimeError("Could not import dac, make sure it is installed, "
                               "please run `pip install descript-audio-codec`") from err
        self.model = dac.utils.load_model(model_type=model_type)
        # All codebooks are active by default.
        self.n_quantizers = self.total_codebooks
        self.model.eval()

    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
        """Not supported, always raises `NotImplementedError`."""
        # We don't support training with this.
        raise NotImplementedError("Forward and training with DAC not supported.")

    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """Encode `x` and return the codes of the active codebooks, with a None scale."""
        codes = self.model.encode(x, self.n_quantizers)[1]
        return codes[:, :self.n_quantizers], None

    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
        """Decode `codes` with the wrapped model. `scale` must be None."""
        assert scale is None
        z_q = self.decode_latent(codes)
        return self.model.decode(z_q)

    def decode_latent(self, codes: torch.Tensor):
        """Decode from the discrete codes to continuous latent space."""
        return self.model.quantizer.from_codes(codes)[0]

    @property
    def channels(self) -> int:
        """Number of audio channels, always 1."""
        return 1

    @property
    def frame_rate(self) -> float:
        """Sample rate of the wrapped model divided by its hop length."""
        return self.model.sample_rate / self.model.hop_length

    @property
    def sample_rate(self) -> int:
        """Sample rate of the wrapped model."""
        return self.model.sample_rate

    @property
    def cardinality(self) -> int:
        """Codebook size of the wrapped model."""
        return self.model.codebook_size

    @property
    def num_codebooks(self) -> int:
        """Active number of codebooks."""
        return self.n_quantizers

    @property
    def total_codebooks(self) -> int:
        """Number of codebooks of the wrapped model."""
        return self.model.n_codebooks

    def set_num_codebooks(self, n: int):
        """Set the active number of codebooks used by the quantizer.

        `n` must be between 1 and `total_codebooks`, inclusive.
        """
        assert n >= 1
        assert n <= self.total_codebooks
        self.n_quantizers = n
|
321 |
+
|
322 |
+
|
323 |
+
class HFEncodecCompressionModel(CompressionModel):
    """Wrapper around HuggingFace Encodec.

    At construction time, the number of codebooks matching each entry of the
    wrapped model's `config.target_bandwidths` is computed. Only those values
    are accepted by `set_num_codebooks`, and the largest one is selected by default.

    Args:
        model (HFEncodecModel): HuggingFace Encodec model to wrap.
    """
    def __init__(self, model: HFEncodecModel):
        super().__init__()
        self.model = model
        bws = self.model.config.target_bandwidths
        # codebooks = bits per second / (frames per second * bits per code).
        # NOTE(review): the `* 1000` assumes bandwidths are given in kbps -- confirm.
        num_codebooks = [
            bw * 1000 / (self.frame_rate * math.log2(self.cardinality))
            for bw in bws
        ]
        # Round to the nearest integer rather than truncating, so that a value
        # such as 3.9999999 does not silently become 3 codebooks.
        rounded = [round(nc) for nc in num_codebooks]
        deltas = [abs(nc - r) for nc, r in zip(num_codebooks, rounded)]
        # Checking we didn't do some bad maths and we indeed have integers!
        # Each delta is checked individually: `all(deltas) <= 1e-3` would compare
        # a single boolean with the tolerance and never validate the values.
        assert all(delta <= 1e-3 for delta in deltas), deltas
        self.possible_num_codebooks = rounded
        self.set_num_codebooks(max(self.possible_num_codebooks))

    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
        """Not supported: this wrapper cannot be used for training."""
        # We don't support training with this.
        raise NotImplementedError("Forward and training with HF EncodecModel not supported.")

    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """Encode `x` at the bandwidth matching the active number of codebooks.

        Returns:
            tuple: `(codes, scale)`, the single entry of each list returned by
                the wrapped model.
        """
        # Map the active number of codebooks back to its target bandwidth.
        bandwidth_index = self.possible_num_codebooks.index(self.num_codebooks)
        bandwidth = self.model.config.target_bandwidths[bandwidth_index]
        res = self.model.encode(x, None, bandwidth)
        # Exactly one entry is expected for both the codes and the scales.
        assert len(res[0]) == 1
        assert len(res[1]) == 1
        return res[0][0], res[1][0]

    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
        """Decode `codes` (and optional `scale`) with the wrapped model."""
        if scale is None:
            scales = [None]  # type: ignore
        else:
            scales = scale  # type: ignore
        # `codes[None]` adds a leading dimension before calling the wrapped model.
        res = self.model.decode(codes[None], scales)
        return res[0]

    def decode_latent(self, codes: torch.Tensor):
        """Decode from the discrete codes to continuous latent space."""
        return self.model.quantizer.decode(codes.transpose(0, 1))

    @property
    def channels(self) -> int:
        """Number of audio channels from the wrapped model's config."""
        return self.model.config.audio_channels

    @property
    def frame_rate(self) -> float:
        """Sample rate divided by the product of the config's upsampling ratios."""
        hop_length = int(np.prod(self.model.config.upsampling_ratios))
        return self.sample_rate / hop_length

    @property
    def sample_rate(self) -> int:
        """Sampling rate from the wrapped model's config."""
        return self.model.config.sampling_rate

    @property
    def cardinality(self) -> int:
        """Codebook size from the wrapped model's config."""
        return self.model.config.codebook_size

    @property
    def num_codebooks(self) -> int:
        """Active number of codebooks."""
        return self._num_codebooks

    @property
    def total_codebooks(self) -> int:
        """Largest supported number of codebooks."""
        return max(self.possible_num_codebooks)

    def set_num_codebooks(self, n: int):
        """Set the active number of codebooks used by the quantizer.

        Raises:
            ValueError: if `n` does not match one of the target bandwidths.
        """
        if n not in self.possible_num_codebooks:
            raise ValueError(f"Allowed values for num codebooks: {self.possible_num_codebooks}")
        self._num_codebooks = n
|
395 |
+
|
396 |
+
|
397 |
+
class InterleaveStereoCompressionModel(CompressionModel):
    """Wraps a CompressionModel to support stereo inputs. The wrapped model
    will be applied independently to the left and right channels, and both codebooks
    will be interleaved. If the wrapped model returns a representation `[B, K ,T]` per
    channel, then the output will be `[B, K * 2, T]` or `[B, K, T * 2]` depending on
    `per_timestep`.

    Args:
        model (CompressionModel): Compression model to wrap.
        per_timestep (bool): Whether to interleave on the timestep dimension
            or on the codebooks dimension.
    """
    def __init__(self, model: CompressionModel, per_timestep: bool = False):
        super().__init__()
        self.model = model
        self.per_timestep = per_timestep
        assert self.model.channels == 1, "Wrapped model is expected to be for monophonic audio"

    @property
    def total_codebooks(self):
        """Total number of codebooks of the wrapped model."""
        return self.model.total_codebooks

    @property
    def num_codebooks(self):
        """Active number of codebooks used by the quantizer.

        ..Warning:: this reports the number of codebooks after the interleaving
        of the codebooks!
        """
        return self.model.num_codebooks if self.per_timestep else self.model.num_codebooks * 2

    def set_num_codebooks(self, n: int):
        """Set the active number of codebooks used by the quantizer.

        ..Warning:: this sets the number of codebooks before the interleaving!
        """
        self.model.set_num_codebooks(n)

    @property
    def num_virtual_steps(self) -> float:
        """Return the number of virtual steps, e.g. one real step
        will be split into that many steps.
        """
        return 2 if self.per_timestep else 1

    @property
    def frame_rate(self) -> float:
        """Frame rate of the wrapped model times the number of virtual steps."""
        return self.model.frame_rate * self.num_virtual_steps

    @property
    def sample_rate(self) -> int:
        """Sample rate of the wrapped model."""
        return self.model.sample_rate

    @property
    def channels(self) -> int:
        """Number of audio channels, always 2 (stereo)."""
        return 2

    @property
    def cardinality(self):
        """Cardinality of each codebook.
        """
        return self.model.cardinality

    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
        """Not supported, use `encode` and `decode` instead."""
        raise NotImplementedError("Not supported, use encode and decode.")

    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """Encode a stereo waveform `[B, 2, T]` channel by channel and interleave the codes.

        Returns:
            tuple: `(indices, scales)`. `scales` is None unless the wrapped model
                returns a scale for both channels, in which case the two scales
                are stacked along dimension 1.
        """
        B, C, T = x.shape
        assert C == self.channels, f"Expecting stereo audio but audio num channels is {C}"

        # Each channel goes through the mono model with its channel dimension kept.
        indices_c0, scales_c0 = self.model.encode(x[:, 0, ...].unsqueeze(1))
        indices_c1, scales_c1 = self.model.encode(x[:, 1, ...].unsqueeze(1))
        indices = torch.stack([indices_c0, indices_c1], dim=0)
        scales: tp.Optional[torch.Tensor] = None
        if scales_c0 is not None and scales_c1 is not None:
            scales = torch.stack([scales_c0, scales_c1], dim=1)

        if self.per_timestep:
            indices = rearrange(indices, 'c b k t -> b k (t c)', c=2)
        else:
            indices = rearrange(indices, 'c b k t -> b (k c) t', c=2)

        return (indices, scales)

    def get_left_right_codes(self, codes: torch.Tensor) -> tp.Tuple[torch.Tensor, torch.Tensor]:
        """Undo the interleaving done by `encode` and return one code tensor per channel."""
        if self.per_timestep:
            codes = rearrange(codes, 'b k (t c) -> c b k t', c=2)
        else:
            codes = rearrange(codes, 'b (k c) t -> c b k t', c=2)
        return codes[0], codes[1]

    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
        """Decode interleaved codes `[B, K, T]` into a stereo waveform.

        Args:
            codes (torch.Tensor): Interleaved codes as returned by `encode`.
            scale (torch.Tensor, optional): Per-channel scales stacked along
                dimension 1, as returned by `encode`.
        Returns:
            torch.Tensor: Both decoded channels concatenated along dimension 1.
        """
        B, K, T = codes.shape
        assert T % self.num_virtual_steps == 0, "Provided codes' number of timesteps does not match"
        assert K == self.num_codebooks, "Provided codes' number of codebooks does not match"

        scale_c0, scale_c1 = None, None
        if scale is not None:
            assert scale.size(0) == B and scale.size(1) == 2, f"Scale has unexpected shape: {scale.shape}"
            # The channel axis is dimension 1 (see the stack in `encode` and the
            # assert above), so the split must happen there, not on the batch axis.
            scale_c0 = scale[:, 0, ...]
            scale_c1 = scale[:, 1, ...]

        codes_c0, codes_c1 = self.get_left_right_codes(codes)
        audio_c0 = self.model.decode(codes_c0, scale_c0)
        audio_c1 = self.model.decode(codes_c1, scale_c1)
        return torch.cat([audio_c0, audio_c1], dim=1)

    def decode_latent(self, codes: torch.Tensor):
        """Decode from the discrete codes to continuous latent space."""
        raise NotImplementedError("Not supported by interleaved stereo wrapped models.")
|
audiocraft/models/genmodel.py
ADDED
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
Base implementation for audio generative models. This base implementation
|
9 |
+
combines all the required components to run inference with pretrained audio
|
10 |
+
generative models. It can be easily inherited by downstream model classes to
|
11 |
+
provide easy access to the generation API.
|
12 |
+
"""
|
13 |
+
|
14 |
+
from abc import ABC, abstractmethod
|
15 |
+
import typing as tp
|
16 |
+
|
17 |
+
import omegaconf
|
18 |
+
import torch
|
19 |
+
|
20 |
+
from .encodec import CompressionModel
|
21 |
+
from .lm import LMModel
|
22 |
+
from .builders import get_wrapped_compression_model
|
23 |
+
from ..data.audio_utils import convert_audio
|
24 |
+
from ..modules.conditioners import ConditioningAttributes
|
25 |
+
from ..utils.autocast import TorchAutocast
|
26 |
+
|
27 |
+
|
28 |
+
class BaseGenModel(ABC):
    """Base generative model with convenient generation API.

    Args:
        name (str): name of the model.
        compression_model (CompressionModel): Compression model
            used to map audio to invertible discrete representations.
        lm (LMModel): Language model over discrete representations.
        max_duration (float, optional): maximum duration the model can produce,
            otherwise, inferred from the training params.
    """
    def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
                 max_duration: tp.Optional[float] = None):
        self.name = name
        self.compression_model = compression_model
        self.lm = lm
        self.cfg: tp.Optional[omegaconf.DictConfig] = None
        # Inference only: make sure both models are in eval mode.
        self.compression_model.eval()
        self.lm.eval()

        # Pick up the config attached to the LM, when there is one.
        if hasattr(lm, 'cfg'):
            lm_cfg = lm.cfg
            assert isinstance(lm_cfg, omegaconf.DictConfig)
            self.cfg = lm_cfg

        if self.cfg is not None:
            self.compression_model = get_wrapped_compression_model(self.compression_model, self.cfg)

        if max_duration is None:
            # Without an explicit value, the config is the only possible source.
            if self.cfg is None:
                raise ValueError("You must provide max_duration when building directly your GenModel")
            max_duration = lm.cfg.dataset.segment_duration  # type: ignore
        assert max_duration is not None

        self.max_duration: float = max_duration
        self.duration = self.max_duration

        # self.extend_stride is the length of audio extension when generating samples longer
        # than self.max_duration. NOTE: the derived class must set self.extend_stride to a
        # positive float value when generating with self.duration > self.max_duration.
        self.extend_stride: tp.Optional[float] = None
        self.device = next(iter(lm.parameters())).device
        self.generation_params: dict = {}
        self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
        # Autocast (float16) is enabled on every device type except CPU.
        if self.device.type != 'cpu':
            self.autocast = TorchAutocast(
                enabled=True, device_type=self.device.type, dtype=torch.float16)
        else:
            self.autocast = TorchAutocast(enabled=False)

    @property
    def frame_rate(self) -> float:
        """Roughly the number of AR steps per second."""
        return self.compression_model.frame_rate

    @property
    def sample_rate(self) -> int:
        """Sample rate of the generated audio."""
        return self.compression_model.sample_rate

    @property
    def audio_channels(self) -> int:
        """Audio channels of the generated audio."""
        return self.compression_model.channels

    def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
        """Override the default progress callback."""
        self._progress_callback = progress_callback

    @abstractmethod
    def set_generation_params(self, *args, **kwargs):
        """Set the generation parameters."""
        raise NotImplementedError("No base implementation for setting generation params.")

    @staticmethod
    @abstractmethod
    def get_pretrained(name: str, device=None):
        """Return a pretrained model. Must be implemented by subclasses."""
        raise NotImplementedError("No base implementation for getting pretrained model")

    @torch.no_grad()
    def _prepare_tokens_and_attributes(
            self,
            descriptions: tp.Sequence[tp.Optional[str]],
            prompt: tp.Optional[torch.Tensor],
    ) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]:
        """Prepare model inputs.

        Args:
            descriptions (list of str): A list of strings used as text conditioning.
            prompt (torch.Tensor): A batch of waveforms used for continuation.
        Returns:
            tuple: the conditioning attributes (one per description) and the encoded
                prompt tokens, or None when no prompt is given.
        """
        attributes = [
            ConditioningAttributes(text={'description': text})
            for text in descriptions]

        if prompt is None:
            return attributes, None

        if descriptions is not None:
            assert len(descriptions) == len(prompt), "Prompt and nb. descriptions doesn't match"
        prompt = prompt.to(self.device)
        prompt_tokens, scale = self.compression_model.encode(prompt)
        # The compression model is expected not to return a scale here.
        assert scale is None
        return attributes, prompt_tokens

    def generate_unconditional(self, num_samples: int, progress: bool = False,
                               return_tokens: bool = False) -> tp.Union[torch.Tensor,
                                                                        tp.Tuple[torch.Tensor, torch.Tensor]]:
        """Generate samples in an unconditional manner.

        Args:
            num_samples (int): Number of samples to be generated.
            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
            return_tokens (bool, optional): If True, also return the generated tokens. Defaults to False.
        """
        descriptions: tp.List[tp.Optional[str]] = [None] * num_samples
        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
        tokens = self._generate_tokens(attributes, prompt_tokens, progress)
        audio = self.generate_audio(tokens)
        if return_tokens:
            return audio, tokens
        return audio

    def generate(self, descriptions: tp.List[str], progress: bool = False, return_tokens: bool = False) \
            -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
        """Generate samples conditioned on text.

        Args:
            descriptions (list of str): A list of strings used as text conditioning.
            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
            return_tokens (bool, optional): If True, also return the generated tokens. Defaults to False.
        """
        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
        assert prompt_tokens is None
        tokens = self._generate_tokens(attributes, prompt_tokens, progress)
        audio = self.generate_audio(tokens)
        if return_tokens:
            return audio, tokens
        return audio

    def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
                              descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None,
                              progress: bool = False, return_tokens: bool = False) \
            -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, torch.Tensor]]:
        """Generate samples conditioned on audio prompts and an optional text description.

        Args:
            prompt (torch.Tensor): A batch of waveforms used for continuation.
                Prompt should be [B, C, T], or [C, T] if only one sample is generated.
            prompt_sample_rate (int): Sampling rate of the given audio waveforms.
            descriptions (list of str, optional): A list of strings used as text conditioning. Defaults to None.
            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
            return_tokens (bool, optional): If True, also return the generated tokens. Defaults to False.
        """
        # A single [C, T] waveform gets a leading batch dimension.
        if prompt.dim() == 2:
            prompt = prompt[None]
        if prompt.dim() != 3:
            raise ValueError("prompt should have 3 dimensions: [B, C, T] (C = 1).")
        prompt = convert_audio(prompt, prompt_sample_rate, self.sample_rate, self.audio_channels)
        if descriptions is None:
            descriptions = [None] * len(prompt)
        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt)
        assert prompt_tokens is not None
        tokens = self._generate_tokens(attributes, prompt_tokens, progress)
        audio = self.generate_audio(tokens)
        if return_tokens:
            return audio, tokens
        return audio

    def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
                         prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor:
        """Generate discrete audio tokens given audio prompt and/or conditions.

        When `self.duration` exceeds `self.max_duration`, generation is done chunk
        by chunk: each chunk is prompted with the tail of the previous one, and the
        window advances by `self.extend_stride` seconds at every step.

        Args:
            attributes (list of ConditioningAttributes): Conditions used for generation (here text).
            prompt_tokens (torch.Tensor, optional): Audio prompt used for continuation.
            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
        Returns:
            torch.Tensor: Generated tokens, with the time steps on the last dimension.
        """
        total_gen_len = int(self.duration * self.frame_rate)
        max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate)
        # Read by the nested callback below, updated by the extension loop.
        current_gen_offset: int = 0

        def _progress_callback(generated_tokens: int, tokens_to_generate: int):
            generated_tokens += current_gen_offset
            if self._progress_callback is None:
                print(f'{generated_tokens: 6d} / {tokens_to_generate: 6d}', end='\r')
            else:
                # Note that total_gen_len might be quite wrong depending on the
                # codebook pattern used, but with delay it is almost accurate.
                self._progress_callback(generated_tokens, tokens_to_generate)

        if prompt_tokens is not None:
            assert max_prompt_len >= prompt_tokens.shape[-1], \
                "Prompt is longer than audio to generate"

        callback = _progress_callback if progress else None

        if self.duration <= self.max_duration:
            # generate by sampling from LM, simple case.
            with self.autocast:
                return self.lm.generate(
                    prompt_tokens, attributes,
                    callback=callback, max_gen_len=total_gen_len, **self.generation_params)

        assert self.extend_stride is not None, "Stride should be defined to generate beyond max_duration"
        assert self.extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
        all_tokens = []
        prompt_length = 0
        if prompt_tokens is not None:
            all_tokens.append(prompt_tokens)
            prompt_length = prompt_tokens.shape[-1]

        stride_tokens = int(self.frame_rate * self.extend_stride)
        while current_gen_offset + prompt_length < total_gen_len:
            time_offset = current_gen_offset / self.frame_rate
            chunk_duration = min(self.duration - time_offset, self.max_duration)
            max_gen_len = int(chunk_duration * self.frame_rate)
            with self.autocast:
                chunk_tokens = self.lm.generate(
                    prompt_tokens, attributes,
                    callback=callback, max_gen_len=max_gen_len, **self.generation_params)
            # Keep only what comes after the prompt; `prompt_length` is 0 without a
            # prompt and always matches the prompt's last dimension otherwise.
            all_tokens.append(chunk_tokens[:, :, prompt_length:])
            # The tail of this chunk is the prompt for the next one.
            prompt_tokens = chunk_tokens[:, :, stride_tokens:]
            prompt_length = prompt_tokens.shape[-1]
            current_gen_offset += stride_tokens

        return torch.cat(all_tokens, dim=-1)

    def generate_audio(self, gen_tokens: torch.Tensor) -> torch.Tensor:
        """Generate Audio from tokens."""
        assert gen_tokens.dim() == 3
        with torch.no_grad():
            return self.compression_model.decode(gen_tokens, None)
|
audiocraft/models/lm.py
ADDED
@@ -0,0 +1,547 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from dataclasses import dataclass
|
8 |
+
from functools import partial
|
9 |
+
import logging
|
10 |
+
import math
|
11 |
+
import typing as tp
|
12 |
+
|
13 |
+
import torch
|
14 |
+
from torch import nn
|
15 |
+
|
16 |
+
from ..utils import utils
|
17 |
+
from ..modules.streaming import StreamingModule, State
|
18 |
+
from ..modules.transformer import StreamingTransformer, create_norm_fn
|
19 |
+
from ..modules.conditioners import (
|
20 |
+
ConditionFuser,
|
21 |
+
ClassifierFreeGuidanceDropout,
|
22 |
+
AttributeDropout,
|
23 |
+
ConditioningProvider,
|
24 |
+
ConditioningAttributes,
|
25 |
+
ConditionType,
|
26 |
+
)
|
27 |
+
from ..modules.codebooks_patterns import CodebooksPatternProvider
|
28 |
+
from ..modules.activations import get_activation_fn
|
29 |
+
|
30 |
+
|
31 |
+
logger = logging.getLogger(__name__)
|
32 |
+
ConditionTensors = tp.Dict[str, ConditionType]
|
33 |
+
CFGConditions = tp.Union[ConditionTensors, tp.Tuple[ConditionTensors, ConditionTensors]]
|
34 |
+
|
35 |
+
|
36 |
+
def get_init_fn(method: str, input_dim: int, init_depth: tp.Optional[int] = None):
    """LM layer initialization.
    Inspired from xlformers: https://github.com/fairinternal/xlformers

    Args:
        method (str): Method name for init function. Valid options are:
            'gaussian', 'uniform'.
        input_dim (int): Input dimension of the initialized module.
        init_depth (int, optional): Optional init depth value used to rescale
            the standard deviation if defined.
    Returns:
        Callable: the torch init function with its parameters already bound.
    Raises:
        ValueError: if `method` is not one of the valid options.
    """
    # Base standard deviation, further divided by sqrt(2 * depth) when a depth is given.
    std = 1 / math.sqrt(input_dim)
    if init_depth is not None:
        std = std / math.sqrt(2 * init_depth)

    if method == 'uniform':
        # A uniform law on [-bound, bound] has a standard deviation of bound / sqrt(3).
        bound = math.sqrt(3) * std
        return partial(torch.nn.init.uniform_, a=-bound, b=bound)

    if method == 'gaussian':
        # Normal law truncated at three standard deviations on each side.
        return partial(
            torch.nn.init.trunc_normal_, mean=0.0, std=std, a=-3 * std, b=3 * std
        )

    raise ValueError("Unsupported layer initialization method")
|
62 |
+
|
63 |
+
|
64 |
+
def init_layer(m: nn.Module,
               method: str,
               init_depth: tp.Optional[int] = None,
               zero_bias_init: bool = False):
    """Wrapper around ``get_init_fn`` for proper initialization of LM modules.

    Only `nn.Linear` and `nn.Embedding` modules are initialized, any other
    module is left untouched.

    Args:
        m (nn.Module): Module to initialize.
        method (str): Method name for the init function.
        init_depth (int, optional): Optional init depth value used to rescale
            the standard deviation if defined.
        zero_bias_init (bool): Whether to initialize the bias to 0 or not.
    """
    is_linear = isinstance(m, nn.Linear)
    if is_linear:
        input_dim, depth = m.in_features, init_depth
    elif isinstance(m, nn.Embedding):
        # Embeddings never use the depth rescaling.
        input_dim, depth = m.embedding_dim, None
    else:
        return

    init_fn = get_init_fn(method, input_dim, init_depth=depth)
    if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16:
        # Half precision weights on CPU are initialized through a float32 copy
        # which is then written back (presumably because the init function is
        # not available for float16 on CPU -- TODO confirm).
        upcast = m.weight.float()
        init_fn(upcast)
        m.weight.data[:] = upcast.half()
    else:
        init_fn(m.weight)

    # Only linear layers get their bias reset.
    if is_linear and zero_bias_init and m.bias is not None:
        nn.init.constant_(m.bias, 0)
|
95 |
+
|
96 |
+
|
97 |
+
class ScaledEmbedding(nn.Embedding):
    """Embedding that can carry its own learning rate.

    Args:
        lr (float, optional): Learning rate to report in the optimizer group.
        *args, **kwargs: Forwarded to `nn.Embedding`.
    """
    def __init__(self, *args, lr=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.lr = lr

    def make_optim_group(self):
        """Return an optimizer parameter group for this module.

        The group always holds the parameters, and an `lr` entry only when a
        learning rate was provided.
        """
        params = list(self.parameters())
        if self.lr is None:
            return {"params": params}
        return {"params": params, "lr": self.lr}
|
109 |
+
|
110 |
+
|
111 |
+
@dataclass
class LMOutput:
    """Output of the language model.

    The logits are already re-aligned with the input codes, hence no extra
    shift is required, e.g. when computing CE.
    """
    logits: torch.Tensor  # [B, K, T, card]
    mask: torch.Tensor  # [B, K, T]
|
117 |
+
|
118 |
+
|
119 |
+
class LMModel(StreamingModule):
|
120 |
+
"""Transformer-based language model on multiple streams of codes.
|
121 |
+
|
122 |
+
Args:
|
123 |
+
pattern_provider (CodebooksPatternProvider): Pattern provider for codebook interleaving.
|
124 |
+
condition_provider (MusicConditioningProvider): Conditioning provider from metadata.
|
125 |
+
fuser (ConditionFuser): Fuser handling the fusing of conditions with language model input.
|
126 |
+
n_q (int): Number of parallel streams to model.
|
127 |
+
card (int): Cardinality, vocabulary size.
|
128 |
+
dim (int): Dimension of the transformer encoder.
|
129 |
+
num_heads (int): Number of heads for the transformer encoder.
|
130 |
+
hidden_scale (int): Scale for hidden feed forward dimension of the transformer encoder.
|
131 |
+
norm (str): Normalization method.
|
132 |
+
norm_first (bool): Use pre-norm instead of post-norm.
|
133 |
+
emb_lr (float, optional): Embedding-specific learning rate.
|
134 |
+
bias_proj (bool): Use bias for output projections.
|
135 |
+
weight_init (str, optional): Method for weight initialization.
|
136 |
+
depthwise_init (str, optional): Method for depthwise weight initialization.
|
137 |
+
zero_bias_init (bool): If true and bias in Linears, initialize bias to zeros.
|
138 |
+
cfg_dropout (float): Classifier-free guidance dropout.
|
139 |
+
cfg_coef (float): Classifier-free guidance coefficient.
|
140 |
+
attribute_dropout (dict): Attribute dropout probabilities.
|
141 |
+
two_step_cfg (bool): Whether to run classifier free-guidance with 2 distinct steps.
|
142 |
+
**kwargs: Additional parameters for the transformer encoder.
|
143 |
+
"""
|
144 |
+
def __init__(self, pattern_provider: CodebooksPatternProvider, condition_provider: ConditioningProvider,
|
145 |
+
fuser: ConditionFuser, n_q: int = 8, card: int = 1024, dim: int = 128, num_heads: int = 8,
|
146 |
+
hidden_scale: int = 4, norm: str = 'layer_norm', norm_first: bool = False,
|
147 |
+
emb_lr: tp.Optional[float] = None, bias_proj: bool = True,
|
148 |
+
weight_init: tp.Optional[str] = None, depthwise_init: tp.Optional[str] = None,
|
149 |
+
zero_bias_init: bool = False, cfg_dropout: float = 0, cfg_coef: float = 1.0,
|
150 |
+
attribute_dropout: tp.Dict[str, tp.Dict[str, float]] = {}, two_step_cfg: bool = False,
|
151 |
+
**kwargs):
|
152 |
+
super().__init__()
|
153 |
+
self.cfg_coef = cfg_coef
|
154 |
+
self.cfg_dropout = ClassifierFreeGuidanceDropout(p=cfg_dropout)
|
155 |
+
self.att_dropout = AttributeDropout(p=attribute_dropout)
|
156 |
+
self.condition_provider = condition_provider
|
157 |
+
self.fuser = fuser
|
158 |
+
self.card = card
|
159 |
+
embed_dim = self.card + 1
|
160 |
+
self.n_q = n_q
|
161 |
+
self.dim = dim
|
162 |
+
self.pattern_provider = pattern_provider
|
163 |
+
self.two_step_cfg = two_step_cfg
|
164 |
+
self.emb = nn.ModuleList([ScaledEmbedding(embed_dim, dim, lr=emb_lr) for _ in range(n_q)])
|
165 |
+
if 'activation' in kwargs:
|
166 |
+
kwargs['activation'] = get_activation_fn(kwargs['activation'])
|
167 |
+
self.transformer = StreamingTransformer(
|
168 |
+
d_model=dim, num_heads=num_heads, dim_feedforward=int(hidden_scale * dim),
|
169 |
+
norm=norm, norm_first=norm_first, **kwargs)
|
170 |
+
self.out_norm: tp.Optional[nn.Module] = None
|
171 |
+
if norm_first:
|
172 |
+
self.out_norm = create_norm_fn(norm, dim)
|
173 |
+
self.linears = nn.ModuleList([nn.Linear(dim, self.card, bias=bias_proj) for _ in range(n_q)])
|
174 |
+
self._init_weights(weight_init, depthwise_init, zero_bias_init)
|
175 |
+
self._fsdp: tp.Optional[nn.Module]
|
176 |
+
self.__dict__['_fsdp'] = None
|
177 |
+
|
178 |
+
def _init_weights(self, weight_init: tp.Optional[str], depthwise_init: tp.Optional[str], zero_bias_init: bool):
    """Initialize the embeddings, the transformer layers and the output projections.

    Args:
        weight_init (str, optional): Name of the initialization method handed to ``init_layer``.
            When None, nothing is initialized here and the method returns immediately.
        depthwise_init (str, optional): Either 'current' (depth is the 1-based index of the
            transformer layer) or 'global' (depth is the total number of transformer layers).
            When None, ``init_depth=None`` is used for every layer.
        zero_bias_init (bool): Forwarded to ``init_layer`` as ``zero_bias_init``.
    """
    assert depthwise_init is None or depthwise_init in ['current', 'global']
    assert depthwise_init is None or weight_init is not None, \
        "If 'depthwise_init' is defined, a 'weight_init' method should be provided."
    assert not zero_bias_init or weight_init is not None, \
        "If 'zero_bias_init', a 'weight_init' method should be provided"

    if weight_init is None:
        return

    # Embeddings and output heads are initialized without any depth information.
    init_without_depth = partial(init_layer, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)

    for embedding in self.emb:
        init_without_depth(embedding)

    total_layers = len(self.transformer.layers)
    for position, block in enumerate(self.transformer.layers, start=1):
        if depthwise_init == 'current':
            layer_depth = position
        elif depthwise_init == 'global':
            layer_depth = total_layers
        else:
            layer_depth = None
        block.apply(partial(init_layer, method=weight_init, init_depth=layer_depth, zero_bias_init=zero_bias_init))

    for head in self.linears:
        init_without_depth(head)
|
211 |
+
|
212 |
+
@property
def special_token_id(self) -> int:
    """Id of the special token: equal to ``card``, i.e. one past the largest regular code."""
    token_id = self.card
    return token_id
|
215 |
+
|
216 |
+
@property
def num_codebooks(self) -> int:
    """Number of codebooks handled by the model (the ``n_q`` constructor argument)."""
    codebooks = self.n_q
    return codebooks
|
219 |
+
|
220 |
+
def forward(self, sequence: torch.Tensor,
            conditions: tp.List[ConditioningAttributes],
            condition_tensors: tp.Optional[ConditionTensors] = None,
            stage: int = -1) -> torch.Tensor:
    """Apply language model on sequence and conditions.
    Given a tensor of sequence of shape [B, K, S] with K the number of codebooks and
    S the sequence steps, return the logits with shape [B, K, S, card].

    Args:
        sequence (torch.Tensor): Indices of the codes to model, of shape [B, K, S].
        conditions (list of ConditioningAttributes): Conditions to use when modeling
            the given codes. Note that when evaluating multiple time with the same conditioning
            you should pre-compute those and pass them as `condition_tensors`.
            Must be empty when `condition_tensors` is provided.
        condition_tensors (dict[str, ConditionType], optional): Pre-computed conditioning
            tensors, see `conditions`. Required when the model is in streaming mode.
        stage (int): The codebook level that is being predicted. Relevant for MAGNeT
            in which prediction is done in a codebook-by-codebook manner.
            Takes values in range(n_q), and ignored by default. When `stage >= 0`,
            `self.attn_mask_per_stage[stage]` is used as self-attention mask; that attribute
            is expected to be provided by a subclass (e.g. MagnetLMModel).
    Returns:
        torch.Tensor: Logits of shape [B, K, S, card].
    """
    _, K, S = sequence.shape
    assert K == self.num_codebooks, "Sequence shape must match the specified number of codebooks"
    # Sum the per-codebook embeddings into a single input stream.
    input_ = sum([self.emb[k](sequence[:, k]) for k in range(K)])
    if condition_tensors is None:
        assert not self._is_streaming, "Conditions tensors should be precomputed when streaming."
        # apply dropout modules
        conditions = self.cfg_dropout(conditions)
        conditions = self.att_dropout(conditions)
        tokenized = self.condition_provider.tokenize(conditions)
        # encode conditions and fuse, both have a streaming cache to not recompute when generating.
        condition_tensors = self.condition_provider(tokenized)
    else:
        assert not conditions, "Shouldn't pass both conditions and condition_tensors."

    input_, cross_attention_input = self.fuser(input_, condition_tensors)

    src_mask = self.attn_mask_per_stage[stage] if stage >= 0 else None
    out = self.transformer(input_, cross_attention_src=cross_attention_input, src_mask=src_mask)
    if self.out_norm is not None:
        out = self.out_norm(out)
    logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1)  # [B, K, S, card]

    # remove the prefix from the model outputs
    if len(self.fuser.fuse2cond['prepend']) > 0:
        logits = logits[:, :, -S:]

    return logits  # [B, K, S, card]
|
268 |
+
|
269 |
+
def compute_predictions(
        self, codes: torch.Tensor,
        conditions: tp.List[ConditioningAttributes],
        condition_tensors: tp.Optional[ConditionTensors] = None,
        stage: int = -1,
        keep_only_valid_steps: bool = True) -> LMOutput:
    """Run the model on codes [B, K, T] through the codebooks interleaving pattern and
    return logits aligned with the original codes.

    Args:
        codes (torch.Tensor): Input codes of shape [B, K, T] with B the batch size,
            K the number of codebooks and T the number of timesteps.
        conditions (list of ConditioningAttributes): Conditionings used when modeling the codes.
            When evaluating several times with the same conditioning, pre-compute them and
            pass them as `condition_tensors` instead.
        condition_tensors (dict[str, ConditionType], optional): Pre-computed conditioning
            tensors, see `conditions`.
        stage (int): The codebook level that is being predicted. Relevant for MAGNeT
            in which prediction is done in a codebook-by-codebook manner.
            Takes values in range(n_q), and ignored by default.
        keep_only_valid_steps (bool): Forwarded to the pattern when building the sequence and
            when reverting the logits: build a sequence from the pattern up to valid
            (= fully defined) steps; steps beyond are replaced by the special token.
    Returns:
        LMOutput: Language model outputs
            logits (torch.Tensor) of shape [B, K, T, card] aligned with the provided codes,
                i.e. the first item holds the logits predicting the first code, so that
                no additional shifting of codes and logits is required.
            mask (torch.Tensor) of shape [B, K, T], mask over valid and invalid positions.
                Depending on the interleaving strategy, part of the logits and codes must
                not be considered as valid predictions because of invalid context.
    """
    batch_size, _, num_steps = codes.shape
    pattern = self.pattern_provider.get_pattern(num_steps)

    # [B, K, T] codes -> [B, K, S] pattern sequence, masked entries filled with special_token_id
    interleaved, _, _ = pattern.build_pattern_sequence(
        codes.contiguous(), self.special_token_id, keep_only_valid_steps=keep_only_valid_steps,
    )

    # run the (possibly FSDP wrapped) model on the pattern sequence
    network = self._fsdp if self._fsdp is not None else self
    sequence_logits = network(interleaved, conditions, condition_tensors, stage=stage)  # [B, K, S, card]

    # revert the pattern on the logits: [B, card, K, S] -> [B, card, K, T]
    # note: nans are used as filler so that consuming unexpected logits is obvious
    reverted, _, valid_positions = pattern.revert_pattern_logits(
        sequence_logits.permute(0, 3, 1, 2), float('nan'), keep_only_valid_steps=keep_only_valid_steps
    )
    logits = reverted.permute(0, 2, 3, 1)  # [B, K, T, card]
    logits_mask = valid_positions[None, :, :].expand(batch_size, -1, -1)  # [K, T] -> [B, K, T]
    return LMOutput(logits, logits_mask)
|
321 |
+
|
322 |
+
def _sample_next_token(self,
                       sequence: torch.Tensor,
                       cfg_conditions: CFGConditions,
                       unconditional_state: State,
                       use_sampling: bool = False,
                       temp: float = 1.0,
                       top_k: int = 0,
                       top_p: float = 0.0,
                       cfg_coef: tp.Optional[float] = None,
                       two_step_cfg: tp.Optional[bool] = None) -> torch.Tensor:
    """Sample next token from the model given a sequence and a set of conditions. The model supports
    multiple sampling strategies (greedy sampling, softmax, top-k, top-p...).

    Args:
        sequence (torch.Tensor): Current sequence of shape [B, K, S]
            with K corresponding to the number of codebooks and S the number of sequence steps.
            S = 1 in streaming mode, except for the first step that contains a bigger prompt.
        cfg_conditions (CFGConditions): Set of conditions. With two step CFG, a tuple
            `(condition_tensors, null_condition_tensors)`. Otherwise a single dict which, if CFG
            is used, should be twice the batch size, being the concatenation of the
            conditions + null conditions. An empty dict disables CFG.
        unconditional_state (State): Streaming state used for the unconditional pass of the
            two step CFG; it is updated in place after that pass.
        use_sampling (bool): Whether to use a sampling strategy or not.
        temp (float): Sampling temperature.
        top_k (int): K for "top-k" sampling.
        top_p (float): P for "top-p" sampling.
        cfg_coef (float, optional): classifier free guidance coefficient.
            Defaults to `self.cfg_coef` when None.
        two_step_cfg (bool, optional): Whether to run the conditional and unconditional
            passes separately. Defaults to `self.two_step_cfg` when None.
    Returns:
        next_token (torch.Tensor): Next token tensor of shape [B, K, 1].
    """
    B = sequence.shape[0]
    cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef
    model = self if self._fsdp is None else self._fsdp
    two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
    if two_step_cfg and cfg_conditions != {}:
        assert isinstance(cfg_conditions, tuple), type(cfg_conditions)
        condition_tensors, null_condition_tensors = cfg_conditions
        cond_logits = model(sequence, conditions=[], condition_tensors=condition_tensors)
        # Swap in the unconditional streaming state for the second pass, then restore.
        state = self.get_streaming_state()
        self.set_streaming_state(unconditional_state)
        uncond_logits = model(sequence, conditions=[], condition_tensors=null_condition_tensors)
        unconditional_state.update(self.get_streaming_state())
        self.set_streaming_state(state)
        # Use the resolved coefficient so that a per-call `cfg_coef` override is honored,
        # exactly like in the single pass branch below.
        logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef
    else:
        assert isinstance(cfg_conditions, dict)
        condition_tensors = cfg_conditions
        if condition_tensors:
            # Preparing for CFG, predicting both conditional and unconditional logits.
            sequence = torch.cat([sequence, sequence], dim=0)
        all_logits = model(
            sequence,
            conditions=[], condition_tensors=condition_tensors)
        if condition_tensors:
            cond_logits, uncond_logits = all_logits.split(B, dim=0)  # [B, K, T, card]
            logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef
        else:
            logits = all_logits

    logits = logits.permute(0, 1, 3, 2)  # [B, K, card, T]
    logits = logits[..., -1]  # [B x K x card]

    # Apply softmax for sampling if temp > 0. Else, do greedy sampling to avoid zero division error.
    if use_sampling and temp > 0.0:
        probs = torch.softmax(logits / temp, dim=-1)
        if top_p > 0.0:
            next_token = utils.sample_top_p(probs, p=top_p)
        elif top_k > 0:
            next_token = utils.sample_top_k(probs, k=top_k)
        else:
            next_token = utils.multinomial(probs, num_samples=1)
    else:
        next_token = torch.argmax(logits, dim=-1, keepdim=True)

    return next_token
|
394 |
+
|
395 |
+
@torch.no_grad()
def generate(self,
             prompt: tp.Optional[torch.Tensor] = None,
             conditions: tp.List[ConditioningAttributes] = [],
             num_samples: tp.Optional[int] = None,
             max_gen_len: int = 256,
             use_sampling: bool = True,
             temp: float = 1.0,
             top_k: int = 250,
             top_p: float = 0.0,
             cfg_coef: tp.Optional[float] = None,
             two_step_cfg: tp.Optional[bool] = None,
             remove_prompts: bool = False,
             check: bool = False,
             callback: tp.Optional[tp.Callable[[int, int], None]] = None,
             **kwargs) -> torch.Tensor:
    """Generate tokens sampling from the model given a prompt or unconditionally. Generation can
    be performed in a greedy fashion or using sampling with top K and top P strategies.

    Args:
        prompt (torch.Tensor, optional): Prompt tokens of shape [B, K, T].
        conditions (list of ConditioningAttributes, optional): List of conditions.
            The list is never mutated here.
        num_samples (int, optional): Number of samples to generate. When None, it is taken from
            the prompt batch size, else from the number of conditions, else defaults to 1.
        max_gen_len (int): Maximum generation length. Must be strictly larger than the prompt length.
        use_sampling (bool): Whether to use a sampling strategy or not.
        temp (float): Sampling temperature.
        top_k (int): K for "top-k" sampling.
        top_p (float): P for "top-p" sampling.
        cfg_coef (float, optional): Classifier-free guidance coefficient.
        two_step_cfg (bool, optional): Whether to perform classifier-free guidance with two steps generation.
        remove_prompts (bool): Whether to remove prompts from generation or not.
        check (bool): Whether to apply further checks on generated sequence.
        callback (Callback, optional): Callback function to report generation progress. It is called
            with the number of generated steps so far and the total number of steps to generate.
        **kwargs: Accepted for signature compatibility; not used by this implementation.
    Returns:
        torch.Tensor: Generated tokens.
    """
    assert not self.training, "generation shouldn't be used in training mode."
    first_param = next(iter(self.parameters()))
    device = first_param.device

    # Checking all input shapes are consistent.
    possible_num_samples = []
    if num_samples is not None:
        possible_num_samples.append(num_samples)
    elif prompt is not None:
        possible_num_samples.append(prompt.shape[0])
    elif conditions:
        possible_num_samples.append(len(conditions))
    else:
        possible_num_samples.append(1)
    # `all(...)` is required here: asserting the bare list would always pass as the list is non-empty.
    assert all(x == possible_num_samples[0] for x in possible_num_samples), "Inconsistent inputs shapes"
    num_samples = possible_num_samples[0]

    # below we create set of conditions: one conditional and one unconditional
    # to do that we merge the regular condition together with the null condition
    # we then do 1 forward pass instead of 2.
    # the reason for that is two-fold:
    # 1. it is about x2 faster than doing 2 forward passes
    # 2. avoid the streaming API treating the 2 passes as part of different time steps
    # We also support doing two different passes, in particular to ensure that
    # the padding structure is exactly the same between train and test.
    # With a batch size of 1, this can be slower though.
    cfg_conditions: CFGConditions
    two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
    if conditions:
        null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
        if two_step_cfg:
            cfg_conditions = (
                self.condition_provider(self.condition_provider.tokenize(conditions)),
                self.condition_provider(self.condition_provider.tokenize(null_conditions)),
            )
        else:
            # builds a new list, the caller's `conditions` is left untouched
            conditions = conditions + null_conditions
            tokenized = self.condition_provider.tokenize(conditions)
            cfg_conditions = self.condition_provider(tokenized)
    else:
        cfg_conditions = {}

    if prompt is None:
        assert num_samples > 0
        prompt = torch.zeros((num_samples, self.num_codebooks, 0), dtype=torch.long, device=device)

    B, K, T = prompt.shape
    start_offset = T
    assert start_offset < max_gen_len

    pattern = self.pattern_provider.get_pattern(max_gen_len)
    # this token is used as default value for codes that are not generated yet
    unknown_token = -1

    # we generate codes up to the max_gen_len that will be mapped to the pattern sequence
    gen_codes = torch.full((B, K, max_gen_len), unknown_token, dtype=torch.long, device=device)
    # filling the gen_codes with the prompt if needed
    gen_codes[..., :start_offset] = prompt
    # create the gen_sequence with proper interleaving from the pattern: [B, K, S]
    gen_sequence, indexes, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id)
    # retrieve the start_offset in the sequence:
    # it is the first sequence step that contains the `start_offset` timestep
    start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
    assert start_offset_sequence is not None

    with self.streaming():
        unconditional_state = self.get_streaming_state()
        prev_offset = 0
        gen_sequence_len = gen_sequence.shape[-1]  # gen_sequence shape is [B, K, S]
        for offset in range(start_offset_sequence, gen_sequence_len):
            # get current sequence (note that the streaming API is providing the caching over previous offsets)
            curr_sequence = gen_sequence[..., prev_offset:offset]
            curr_mask = mask[None, ..., prev_offset:offset].expand(B, -1, -1)
            if check:
                # check coherence between mask and sequence
                assert (curr_sequence == torch.where(curr_mask, curr_sequence, self.special_token_id)).all()
                # should never happen as gen_sequence is filled progressively
                assert not (curr_sequence == unknown_token).any()
            # sample next token from the model, next token shape is [B, K, 1]
            next_token = self._sample_next_token(
                curr_sequence, cfg_conditions, unconditional_state, use_sampling, temp, top_k, top_p,
                cfg_coef=cfg_coef, two_step_cfg=two_step_cfg)
            # ensure the tokens that should be masked are properly set to special_token_id
            # as the model never output special_token_id
            valid_mask = mask[..., offset:offset+1].expand(B, -1, -1)
            next_token[~valid_mask] = self.special_token_id
            # ensure we don't overwrite prompt tokens, we only write over unknown tokens
            # (then mask tokens should be left as is as well, which is correct)
            gen_sequence[..., offset:offset+1] = torch.where(
                gen_sequence[..., offset:offset+1] == unknown_token,
                next_token, gen_sequence[..., offset:offset+1]
            )
            prev_offset = offset
            if callback is not None:
                callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence)
    unconditional_state.clear()

    # ensure sequence has been entirely filled
    assert not (gen_sequence == unknown_token).any()
    # ensure gen_sequence pattern and mask are matching
    # which means the gen_sequence is valid according to the pattern
    assert (
        gen_sequence == torch.where(mask[None, ...].expand(B, -1, -1), gen_sequence, self.special_token_id)
    ).all()
    # get back the codes, trimming the prompt if needed and cutting potentially incomplete timesteps
    out_codes, out_indexes, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token)

    # sanity checks over the returned codes and corresponding masks
    assert (out_codes[..., :max_gen_len] != unknown_token).all()
    assert (out_mask[..., :max_gen_len] == 1).all()

    out_start_offset = start_offset if remove_prompts else 0
    out_codes = out_codes[..., out_start_offset:max_gen_len]

    # ensure the returned codes are all valid
    assert (out_codes >= 0).all() and (out_codes <= self.card).all()
    return out_codes
|
audiocraft/models/lm_magnet.py
ADDED
@@ -0,0 +1,498 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import logging
|
8 |
+
import math
|
9 |
+
import typing as tp
|
10 |
+
import torch
|
11 |
+
import numpy as np
|
12 |
+
|
13 |
+
from ..utils import utils
|
14 |
+
from ..modules.conditioners import (
|
15 |
+
ClassifierFreeGuidanceDropout,
|
16 |
+
ConditioningAttributes,
|
17 |
+
ConditionType,
|
18 |
+
)
|
19 |
+
from .lm import LMModel
|
20 |
+
|
21 |
+
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Mapping from a condition name to its ConditionType value.
ConditionTensors = tp.Dict[str, ConditionType]
# Conditions used for classifier-free guidance: either one mapping or a pair of mappings.
CFGConditions = tp.Union[
    ConditionTensors,
    tp.Tuple[ConditionTensors, ConditionTensors],
]
|
24 |
+
|
25 |
+
|
26 |
+
class MagnetLMModel(LMModel):
|
27 |
+
"""Transformer-based, non-autoregressive model, operates on multiple streams of audio tokens (MAGNeT).
|
28 |
+
Args:
|
29 |
+
subcodes_context (int): The number of timesteps attended in the self-attention blocks of codebooks > 0.
|
30 |
+
When set to -1, attention is unrestricted and all timesteps are attended. Defaults to 5.
|
31 |
+
compression_model_framerate (int): frame rate of the audio tokenizer.
|
32 |
+
segment_duration (int): Sample length in seconds.
|
33 |
+
span_len (int): Determines the length of masking spans. This is the minimal length of consecutive masked tokens,
|
34 |
+
for both training and inference. Defaults to 3.
|
35 |
+
**kwargs: Additional parameters for the LMModel.
|
36 |
+
"""
|
37 |
+
def __init__(self, subcodes_context: int = 5, compression_model_framerate: int = 50,
             segment_duration: int = 10, span_len: int = 3, **kwargs):
    """Build the underlying LMModel, then precompute the per-stage attention masks.

    ``kwargs`` is forwarded to ``LMModel.__init__`` and must contain the
    'causal', 'num_heads', 'device' and 'dtype' entries, which are read here.
    """
    super().__init__(**kwargs)
    self.causal = kwargs['causal']
    self.subcodes_context = subcodes_context
    self.span_len = span_len

    # Masks are sized for a full segment: framerate * duration tokens.
    num_heads = kwargs['num_heads']
    device = kwargs['device']
    dtype = kwargs['dtype']
    self._build_attn_masks(
        compression_model_framerate=compression_model_framerate,
        segment_duration=segment_duration,
        num_heads=num_heads,
        device=device,
        dtype=dtype,
    )
|
47 |
+
|
48 |
+
def restricted_context_attn_mask(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """Build a non-causal local attention mask of shape [seq_len, seq_len].

    Entry (i, j) is 0 when ``|i - j| <= self.subcodes_context`` and ``-inf`` otherwise.
    Args:
        seq_len (int): token sequence length.
        device (torch.device): device of the output tensor.
        dtype (torch.dtype): data type of the output tensor.
    Returns:
        torch.Tensor: The restricted attention mask.
    """
    positions = torch.arange(seq_len, device=device)
    # pairwise distance between query (rows) and key (columns) positions
    distance = torch.abs(positions.view(-1, 1) - positions.view(1, -1))
    within_context = distance <= self.subcodes_context

    attend = torch.zeros([], device=device, dtype=dtype)
    ignore = torch.full([], float('-inf'), device=device, dtype=dtype)
    return torch.where(within_context, attend, ignore)
|
68 |
+
|
69 |
+
def _stage_attn_mask(self, stage: int, seq_len: int, num_heads: int,
                     device: torch.device, dtype: torch.dtype) -> tp.Optional[torch.Tensor]:
    """Build the self-attention mask used for a given stage (codebook index).

    Args:
        stage (int): The codebook index.
        seq_len (int): Token sequence length.
        num_heads (int): Num transformer attention heads.
        device (torch.device): device of the output tensor.
        dtype (torch.dtype): data type of the output tensor.
    Returns:
        torch.Tensor: A mask of shape [1, num_heads, L, L], with L being seq_len rounded up to a
            multiple of 8, or None when attention is unrestricted (stage 0, or a negative
            ``subcodes_context``). The padded area beyond seq_len is filled with zeros.
    """
    # only codebooks > 0 with a non-negative context get a restricted (local) attention
    if stage <= 0 or self.subcodes_context <= -1:
        return None

    # parallel - non-causal - with restricted subcodes context, repeated for each attention head
    local_mask = self.restricted_context_attn_mask(seq_len, device=device, dtype=dtype)
    per_head_mask = local_mask.repeat((1, num_heads, 1, 1))

    # align8 to enable memory efficient attention
    align_factor = 8
    padded_len = math.ceil(seq_len / align_factor) * align_factor

    padded_mask = torch.zeros((1, num_heads, padded_len, padded_len), device=device, dtype=dtype)
    padded_mask[..., :seq_len, :seq_len] = per_head_mask
    return padded_mask
|
101 |
+
|
102 |
+
def _build_attn_masks(self, compression_model_framerate: int, segment_duration: int, num_heads: int,
                      device: torch.device, dtype: torch.dtype):
    """Precompute one attention mask per stage and store them in ``self.attn_mask_per_stage``.

    The list has ``self.n_q`` entries, one per RVQ codebook level; each entry is whatever
    ``self._stage_attn_mask`` returns for that stage (a local attention map or None).
    Args:
        compression_model_framerate (int): The frame rate of the tokenizer.
        segment_duration (int): Sample length in seconds.
        num_heads (int): Num transformer attention heads.
        device (torch.device): device of the output tensor.
        dtype (torch.dtype): data type of the output tensor.
    """
    # number of tokens in a full segment
    tokens_per_segment = compression_model_framerate * segment_duration

    masks = []
    for stage in range(self.n_q):
        masks.append(self._stage_attn_mask(stage, tokens_per_segment, num_heads, device, dtype))
    self.attn_mask_per_stage = masks
|
116 |
+
|
117 |
+
@torch.no_grad()
def generate(self,
             prompt: tp.Optional[torch.Tensor] = None,
             conditions: tp.List[ConditioningAttributes] = [],
             num_samples: tp.Optional[int] = None,
             max_gen_len: int = 256,
             use_sampling: bool = True,
             temp: float = 1.0,
             top_k: int = 250,
             top_p: float = 0.0,
             cfg_coef: tp.Optional[float] = None,
             two_step_cfg: tp.Optional[bool] = None,
             remove_prompts: bool = False,
             check: bool = False,
             callback: tp.Optional[tp.Callable[[int, int], None]] = None,
             **kwargs) -> torch.Tensor:
    """Generation entry point sharing the LMModel signature, delegating to ``_generate_magnet``.

    ``cfg_coef``, ``two_step_cfg``, ``remove_prompts`` and ``check`` are not supported and must
    be left at their default value. All other named arguments, together with any extra
    ``**kwargs``, are forwarded to ``_generate_magnet``.
    """
    # Reject the LMModel options that have no MAGNeT counterpart.
    assert cfg_coef is None, "Unsupported in MAGNeT. Use max_cfg_coef,min_cfg_coef instead."
    assert two_step_cfg is None, "MAGNeT currently doesn't support two step classifier-free-guidance."
    assert remove_prompts is False, "MAGNeT currently doesn't support the remove_prompts arg."
    assert check is False, "MAGNeT currently doesn't support the check arg."

    # Call the MAGNeT-specific generation method
    forwarded = dict(
        prompt=prompt,
        conditions=conditions,
        num_samples=num_samples,
        max_gen_len=max_gen_len,
        use_sampling=use_sampling,
        temp=temp,
        top_k=top_k,
        top_p=top_p,
        callback=callback,
    )
    return self._generate_magnet(**forwarded, **kwargs)
|
148 |
+
|
149 |
+
@torch.no_grad()
def _generate_magnet(self,
                     prompt: tp.Optional[torch.Tensor] = None,
                     conditions: tp.List[ConditioningAttributes] = [],
                     num_samples: tp.Optional[int] = None,
                     max_gen_len: int = 256,
                     use_sampling: bool = True,
                     temp: float = 3.0,
                     top_k: int = 0,
                     top_p: float = 0.9,
                     callback: tp.Optional[tp.Callable[[int, int], None]] = None,
                     max_cfg_coef: float = 10.0,
                     min_cfg_coef: float = 1.0,
                     decoding_steps: tp.List[int] = [20, 10, 10, 10],
                     anneal_temp: bool = True,
                     span_scoring: str = 'max',
                     span_arrangement: str = 'nonoverlap') -> torch.Tensor:
    """Generate audio tokens given textual conditions, and optionally given audio prompts,
    by running MAGNeT's iterative decoding algorithm for each of the n_q RVQ levels.

    Args:
        prompt (torch.Tensor): Prompt tokens of shape [B, K, T].
        conditions (list of ConditioningAttributes): List of conditions.
        num_samples (int): Number of samples to generate when no prompt and no conditions are given.
        max_gen_len (int): Maximum generation length.
        use_sampling (bool): Whether to use a sampling strategy or not.
        temp (float): Initial sampling temperature.
        top_k (int): k for "top-k" sampling.
        top_p (float): p for "top-p" sampling.
        callback (Callback): Callback function to report generation progress.
        max_cfg_coef (float): Initial coefficient used for classifier free guidance.
        min_cfg_coef (float): Final coefficient used for classifier free guidance.
        decoding_steps (list of n_q ints): The number of iterative decoding steps,
            for each of the n_q RVQ codebooks. Stages are paired with entries via `zip`,
            so a list shorter than n_q leaves the remaining stages ungenerated.
        anneal_temp (bool): When set to True, softmax temperature will be linearly decayed to zero, at each stage.
        span_scoring (str): Use the maximum probability of each span ('max')
            or the product of probabilities ('prod').
        span_arrangement (str): Use either non-overlapping spans ('nonoverlap') or overlapping spans ('stride1')
            in the masking scheme.
    Returns:
        torch.Tensor: Generated tokens.
    """
    assert not self.training, "generation shouldn't be used in training mode."
    first_param = next(iter(self.parameters()))
    device = first_param.device

    # Resolve the batch size: an explicit num_samples wins, then the prompt's
    # batch dimension, then the number of conditions, and finally a single sample.
    possible_num_samples = []
    if num_samples is not None:
        possible_num_samples.append(num_samples)
    elif prompt is not None:
        possible_num_samples.append(prompt.shape[0])
    elif conditions:
        possible_num_samples.append(len(conditions))
    else:
        possible_num_samples.append(1)
    # `all(...)` is required here: asserting the bare list would always pass,
    # because a non-empty list is truthy whatever it contains.
    assert all(x == possible_num_samples[0] for x in possible_num_samples), "Inconsistent inputs shapes"
    num_samples = possible_num_samples[0]

    # below we create set of conditions: one conditional and one unconditional
    # to do that we merge the regular condition together with the null condition
    # we then do 1 forward pass instead of 2.
    cfg_conditions: tp.Optional[ConditionTensors]
    if conditions:
        null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
        # builds a new list; the caller's `conditions` list is not modified
        conditions = conditions + null_conditions
        tokenized = self.condition_provider.tokenize(conditions)
        cfg_conditions = self.condition_provider(tokenized)
    else:
        cfg_conditions = {}

    if prompt is None:
        assert num_samples > 0
        # empty prompt: zero time steps for every codebook
        prompt = torch.zeros((num_samples, self.num_codebooks, 0), dtype=torch.long, device=device)

    B, K, prompt_length = prompt.shape
    start_offset = prompt_length
    assert start_offset < max_gen_len

    mask_id = self.special_token_id

    # we generate codes with a fixed sequence length
    shape = (B, K, max_gen_len)

    gen_codes = torch.full(shape, mask_id, dtype=torch.long, device=device)
    # filling the gen_codes with the prompt if needed
    gen_codes[..., :start_offset] = prompt
    # the codes tensor is used directly as the generation sequence: [B, K, T]
    gen_sequence = gen_codes

    curr_step = 0
    for stage, n_steps in zip(range(self.n_q), decoding_steps):
        gen_sequence, curr_step = self._generate_stage(gen_sequence,
                                                       cfg_conditions,
                                                       stage=stage,
                                                       device=device,
                                                       prompt_length=prompt_length,
                                                       prompt=prompt,
                                                       temp=temp,
                                                       max_cfg_coef=max_cfg_coef,
                                                       min_cfg_coef=min_cfg_coef,
                                                       top_k=top_k,
                                                       top_p=top_p,
                                                       timesteps=n_steps,
                                                       anneal_temp=anneal_temp,
                                                       span_scoring=span_scoring,
                                                       use_sampling=use_sampling,
                                                       span_arrangement=span_arrangement,
                                                       curr_step=curr_step,
                                                       total_steps=sum(decoding_steps),
                                                       callback=callback)

    return gen_sequence
|
261 |
+
|
262 |
+
@torch.no_grad()
def _generate_stage(self,
                    gen_sequence: torch.Tensor,
                    condition_tensors: tp.Optional[ConditionTensors],
                    stage: int,
                    device: torch.device,
                    prompt_length: int = 0,
                    prompt: tp.Optional[torch.Tensor] = None,
                    use_sampling: bool = True,
                    temp: float = 3.0,
                    max_cfg_coef: float = 10.0,
                    min_cfg_coef: float = 1.0,
                    top_k: int = 0,
                    top_p: float = 0.0,
                    timesteps: int = 10,
                    anneal_temp: bool = True,
                    span_scoring: str = 'max',
                    span_arrangement: str = 'nonoverlap',
                    curr_step: int = 0,
                    total_steps: int = 0,
                    callback: tp.Optional[tp.Callable[[int, int], None]] = None) -> tp.Tuple[torch.Tensor, int]:
    """Generate audio tokens of a single RVQ level (stage), given the previously generated stages,
    and the textual conditions.

    Args:
        gen_sequence (torch.Tensor): Previously generated tokens. Updated in place for `stage`;
            with non-overlapping spans it is also trimmed to a multiple of the span length.
        condition_tensors (tp.Optional[ConditionTensors]): pre-computed conditioning tensors.
            When empty, no classifier free guidance is applied.
        stage (int): RVQ level to generate.
        device (torch.device): device of the output tensor.
        prompt_length (int): Temporal length of the audio prompt.
        prompt (torch.Tensor): Prompt tokens of shape [B, K, T].
        use_sampling (bool): Whether to use a sampling strategy or not.
        temp (float): Initial sampling temperature.
        max_cfg_coef (float): Initial coefficient used for classifier free guidance.
        min_cfg_coef (float): Final coefficient used for classifier free guidance.
        top_k (int): k for "top-k" sampling.
        top_p (float): p for "top-p" sampling.
        timesteps (int): Number of iterative decoding steps.
        anneal_temp (bool): When set to True, softmax temperature will be linearly decayed to zero, at each stage.
        span_scoring (str): Use the maximum probability of each span ('max')
            or the product of probabilities ('prod').
        span_arrangement (str): Use either non-overlapping spans ('nonoverlap') or overlapping spans ('stride1')
            in the masking scheme.
        curr_step (int): Global iterative decoding step counter.
        total_steps (int): Total decoding steps.
        callback (Callback): Callback function to report generation progress.
    Returns:
        tuple(torch.Tensor, int): Generated tokens and the current decoding step counter
            (the counter only advances when a callback is given).
    """
    B, K, T = gen_sequence.shape
    shape = (B, 1, T)  # generating a single codebook per stage

    mask_id = self.special_token_id
    stage_gen_seq = torch.full(shape, mask_id, dtype=torch.long, device=device)

    assert span_arrangement in ('nonoverlap', 'stride1')
    chunk_masking = self.span_len > 1 and span_arrangement == 'nonoverlap'
    # loop-invariant, so computed once up front
    run_lps_masking = (span_arrangement == 'stride1') and self.span_len > 1

    # Large negative score: positions carrying it sort last in the top-k
    # selection of positions to re-mask.
    DONT_REMASK_ME_SCORE = -1e4

    model = self if self._fsdp is None else self._fsdp

    if chunk_masking:
        # span-wise scores
        n_chunks = T // self.span_len
        if T % self.span_len != 0:
            # trim sequence ending to achieve a multiple of span_len
            T = self.span_len * n_chunks
            gen_sequence = gen_sequence[..., :T]
            stage_gen_seq = stage_gen_seq[..., :T]

        chunked_shape = (B, 1, n_chunks)
        n_prompt_chunks = prompt_length // self.span_len
        scores = torch.zeros(chunked_shape, dtype=torch.float32, device=device)
        scores[..., :n_prompt_chunks] = DONT_REMASK_ME_SCORE
        num_chunks_to_gen = n_chunks - n_prompt_chunks
    else:
        # token-wise scores
        scores = torch.zeros(shape, dtype=torch.float32, device=device)
        scores[..., :prompt_length] = DONT_REMASK_ME_SCORE
        gen_T = T - prompt_length

    # run MAGNeT iterative decoding for "timesteps" iterations
    for timestep, steps_left in zip(torch.linspace(0, 1, timesteps, device=device), reversed(range(timesteps))):

        # cosine schedule: fraction of positions to (re-)mask at this step
        mask_p = torch.cos(timestep * math.pi * 0.5)

        if chunk_masking:
            num_masked = max(int((mask_p * num_chunks_to_gen).item()), 1)
        else:
            num_masked = max(int((mask_p * gen_T).item()), 1)

        # masking
        if run_lps_masking:
            # masking of the k least probable overlapping (stride 1) spans
            mask = torch.concat((
                [self._least_probable_span_masking(scores[[i], :, :], num_masked).to(device)
                 for i in range(B)]), dim=0)
            stage_gen_seq[mask] = mask_id
        else:
            # masking of the k least probable non-overlapping spans
            masked = scores.topk(num_masked, dim=-1).indices
            if chunk_masking:
                chunks_mask = torch.full(chunked_shape, False, dtype=torch.bool, device=device)
                chunks_mask = chunks_mask.scatter(2, masked, True)
                # expand the per-chunk mask back to token resolution
                mask = torch.repeat_interleave(chunks_mask, self.span_len, dim=-1)
                stage_gen_seq[mask] = mask_id
            else:
                stage_gen_seq = stage_gen_seq.scatter(2, masked, mask_id)

        if prompt is not None:
            # prompt tokens are never regenerated
            stage_gen_seq[..., :prompt_length] = prompt[:, stage, :].unsqueeze(1)

        gen_sequence[:, [stage], :] = stage_gen_seq
        if condition_tensors:
            # duplicate input for classifier free guidance
            sequence = torch.cat([gen_sequence, gen_sequence], dim=0)
        else:
            # No conditioning: run a single pass on the sequence itself. Without
            # this branch `sequence` is unbound on the first step.
            sequence = gen_sequence

        all_logits = model(sequence, [], condition_tensors, stage=stage)

        if condition_tensors:
            # classifier free guidance with annealing
            cond_logits, uncond_logits = all_logits.split(B, dim=0)  # [B, K, T, card]
            clsfg_coef = float(mask_p) * max_cfg_coef + (1 - float(mask_p)) * min_cfg_coef
            logits = uncond_logits + (cond_logits - uncond_logits) * clsfg_coef
        else:
            logits = all_logits

        # temperature annealing - linear
        t = temp * (steps_left / timesteps) if anneal_temp else temp

        # sampling (temperature is floored at 1e-2 to avoid dividing by zero)
        logits = logits[:, stage, :, :].unsqueeze(1)
        probs = torch.softmax(logits / max(t, 1e-2), dim=-1)
        if use_sampling:
            if top_p > 0.0:
                sampled_tokens = utils.sample_top_p(probs, p=top_p)
            elif top_k > 0:
                sampled_tokens = utils.sample_top_k(probs, k=top_k)
            else:
                sampled_tokens = utils.multinomial(probs, num_samples=1)
        else:
            sampled_tokens = torch.argmax(logits, dim=-1, keepdim=True)

        # write the sampled tokens into the positions that are currently masked
        mask = stage_gen_seq == mask_id
        stage_gen_seq = torch.where(mask, sampled_tokens[..., 0], stage_gen_seq)
        gen_sequence[:, [stage], :] = stage_gen_seq

        # get probs of sampled tokens
        sampled_probs = torch.gather(probs, 3, sampled_tokens)[..., 0]

        # span scoring
        if chunk_masking:
            if span_scoring == 'max':
                # max in linear space
                scores = 1 - torch.max(sampled_probs.reshape((B, 1, n_chunks, -1)), dim=-1)[0]
            elif span_scoring == 'prod':
                # prod in log space
                scores = torch.sum(-torch.log(sampled_probs).reshape((B, 1, n_chunks, -1)), dim=-1)
            else:
                raise NotImplementedError
        else:
            # prod in log space for lps masking (stride1)
            scores = -torch.log(sampled_probs)

        # Positions that were not masked at this step get DONT_REMASK_ME_SCORE,
        # so they rank last when choosing what to re-mask next.
        if chunk_masking:
            scores = scores.masked_fill(~chunks_mask, DONT_REMASK_ME_SCORE)
        else:
            scores = scores.masked_fill(~mask, DONT_REMASK_ME_SCORE)

        if callback is not None:
            curr_step += 1
            callback(curr_step, total_steps)

    return gen_sequence, curr_step
|
439 |
+
|
440 |
+
def _construct_spans_mask(self, span_starts: torch.Tensor, T: int, device: torch.device) -> torch.Tensor:
    """Build a [1x1xT] boolean mask made of (possibly overlapping) spans of True values.

    Every position selected by span_starts opens a span of self.span_len
    consecutive True values; spans running past T are cut at the sequence end.

    Args:
        span_starts (torch.Tensor): Selects the first position of each span along the
            time axis (`_least_probable_span_masking` passes a tensor of indices).
        T (int): Sequence length.
        device (torch.device): device of the output tensor.
    Returns:
        torch.Tensor: Spans mask of shape [1x1xT]
    """
    spans = torch.zeros((1, 1, T), dtype=torch.bool, device=device)
    spans[:, :, span_starts] = True

    # Slide a copy of the start markers right one step at a time, OR-ing it in,
    # until each start has grown into a span of span_len positions.
    shifted = spans.clone()
    remaining_shifts = self.span_len - 1
    while remaining_shifts > 0:
        left_pad = torch.zeros((1, 1, 1), dtype=torch.bool, device=device)
        shifted = torch.cat((left_pad, shifted[:, :, :-1]), dim=-1)
        spans = spans | shifted
        remaining_shifts -= 1
    return spans
|
458 |
+
|
459 |
+
def _least_probable_span_masking(self, scores: torch.Tensor, num_masked_trg: int) -> torch.Tensor:
    """Construct a [1x1xT] boolean mask covering the u highest-scoring (least probable)
    overlapping spans, with u chosen by binary search so that the number of masked
    tokens is close to num_masked_trg.

    Args:
        scores (torch.Tensor): Per token score [-log(prob)]
        num_masked_trg (int): The desired amount of tokens to be masked
            (raised to at least self.span_len).
    Returns:
        torch.Tensor: Spans mask of shape [1x1xT]
    """
    span_len = self.span_len
    seq_len = scores.shape[-1]
    device = scores.device

    # Score of a span = sum of its token scores (product of probs in log space),
    # for every stride-1 window of span_len tokens.
    span_scores = scores.unfold(2, span_len, 1).sum(dim=-1)
    ranked_starts = torch.argsort(span_scores[0, 0], descending=True)

    target = max(num_masked_trg, span_len)

    # Search bounds for u: fully disjoint spans need the fewest spans to reach
    # `target` tokens, maximally overlapping spans need the most.
    lo = target // span_len
    hi = target - span_len + 1
    u = round(0.5 * (lo + hi))

    if u in (lo, hi):
        return self._construct_spans_mask(ranked_starts[:u], seq_len, device)

    while lo < u < hi:
        mask = self._construct_spans_mask(ranked_starts[:u], seq_len, device)
        if mask.sum() > target:
            hi = u
        else:
            lo = u
        u = round(0.5 * (lo + hi))

    # mask built for the last u that was strictly inside the search interval
    return mask
|
audiocraft/models/loaders.py
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
Utility functions to load from the checkpoints.
|
9 |
+
Each checkpoint is a torch.saved dict with the following keys:
|
10 |
+
- 'xp.cfg': the hydra config as dumped during training. This should be used
|
11 |
+
to rebuild the object using the audiocraft.models.builders functions,
|
12 |
+
- 'best_state': a readily loadable best state for the model, including
|
13 |
+
the conditioner. The model obtained from `xp.cfg` should be compatible
|
14 |
+
with this state dict. In the case of a LM, the encodec model would not be
|
15 |
+
bundled along but instead provided separately.
|
16 |
+
|
17 |
+
Those functions also support loading from a remote location with the Torch Hub API.
|
18 |
+
They also support overriding some parameters, in particular the device and dtype
|
19 |
+
of the returned model.
|
20 |
+
"""
|
21 |
+
|
22 |
+
from pathlib import Path
|
23 |
+
from huggingface_hub import hf_hub_download
|
24 |
+
import typing as tp
|
25 |
+
import os
|
26 |
+
|
27 |
+
from omegaconf import OmegaConf, DictConfig
|
28 |
+
import torch
|
29 |
+
|
30 |
+
import audiocraft
|
31 |
+
from . import builders
|
32 |
+
from .encodec import CompressionModel
|
33 |
+
|
34 |
+
|
35 |
+
def get_audiocraft_cache_dir() -> tp.Optional[str]:
    """Return the value of the `AUDIOCRAFT_CACHE_DIR` environment variable, or None when unset."""
    return os.getenv('AUDIOCRAFT_CACHE_DIR')
|
37 |
+
|
38 |
+
|
39 |
+
def _get_state_dict(
    file_or_url_or_id: tp.Union[Path, str],
    filename: tp.Optional[str] = None,
    device='cpu',
    cache_dir: tp.Optional[str] = None,
):
    """Load a checkpoint package from a local file, a local directory, an https URL
    or a Hugging Face Hub repository id.

    The source is resolved in this order:
      1. an existing file: loaded directly;
      2. an existing directory: `filename` inside it is loaded;
      3. a string starting with 'https://': downloaded through torch.hub;
      4. anything else: treated as a Hugging Face Hub repo id holding `filename`.

    Args:
        file_or_url_or_id (Path or str): File path, directory path, https URL or HF repo id.
        filename (str, optional): Checkpoint file name; required for cases 2 and 4.
        device: `map_location` used when loading the tensors.
        cache_dir (str, optional): Cache directory for Hub downloads. Defaults to
            `get_audiocraft_cache_dir()`.
    Returns:
        The object stored in the checkpoint, as returned by `torch.load`.
    Raises:
        AssertionError: if `filename` is missing for a directory or HF repo id.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    file_or_url_or_id = str(file_or_url_or_id)

    if os.path.isfile(file_or_url_or_id):
        return torch.load(file_or_url_or_id, map_location=device)

    if os.path.isdir(file_or_url_or_id):
        # Without this guard a missing filename would build the path "<dir>/None".
        assert filename is not None, "filename needs to be defined if loading from a directory"
        file = os.path.join(file_or_url_or_id, filename)
        return torch.load(file, map_location=device)

    if file_or_url_or_id.startswith('https://'):
        return torch.hub.load_state_dict_from_url(file_or_url_or_id, map_location=device, check_hash=True)

    assert filename is not None, "filename needs to be defined if using HF checkpoints"

    # The cache directory is only relevant for Hub downloads.
    if cache_dir is None:
        cache_dir = get_audiocraft_cache_dir()
    file = hf_hub_download(
        repo_id=file_or_url_or_id, filename=filename, cache_dir=cache_dir,
        library_name="audiocraft", library_version=audiocraft.__version__)
    return torch.load(file, map_location=device)
|
68 |
+
|
69 |
+
|
70 |
+
def load_compression_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None):
    """Load the raw compression model checkpoint package (`compression_state_dict.bin`)."""
    checkpoint_name = "compression_state_dict.bin"
    return _get_state_dict(file_or_url_or_id, filename=checkpoint_name, cache_dir=cache_dir)
|
72 |
+
|
73 |
+
|
74 |
+
def load_compression_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
    """Build a compression model from a checkpoint and return it in eval mode.

    If the checkpoint package has a 'pretrained' entry, the model is obtained through
    `CompressionModel.get_pretrained`; otherwise it is rebuilt from the stored
    'xp.cfg' config and loaded with the 'best_state' weights.
    """
    package = load_compression_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)
    if 'pretrained' in package:
        return CompressionModel.get_pretrained(package['pretrained'], device=device)

    config = OmegaConf.create(package['xp.cfg'])
    config.device = str(device)

    compression_model = builders.get_compression_model(config)
    compression_model.load_state_dict(package['best_state'])
    compression_model.eval()
    return compression_model
|
84 |
+
|
85 |
+
|
86 |
+
def load_lm_model_ckpt(file_or_url_or_id: tp.Union[Path, str], cache_dir: tp.Optional[str] = None):
    """Load the raw language model checkpoint package (`state_dict.bin`)."""
    checkpoint_name = "state_dict.bin"
    return _get_state_dict(file_or_url_or_id, filename=checkpoint_name, cache_dir=cache_dir)
|
88 |
+
|
89 |
+
|
90 |
+
def _delete_param(cfg: DictConfig, full_name: str):
    """Remove the entry addressed by the dotted path `full_name` from `cfg`, if it exists.

    Does nothing when any intermediate key of the path is missing. Struct mode is
    switched off on the parent node for the deletion and switched on afterwards.
    """
    *parent_keys, leaf = full_name.split('.')

    node = cfg
    for key in parent_keys:
        if key not in node:
            return
        node = node[key]

    OmegaConf.set_struct(node, False)
    if leaf in node:
        del node[leaf]
    OmegaConf.set_struct(node, True)
|
101 |
+
|
102 |
+
|
103 |
+
def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
    """Build a language model from a checkpoint and return it in eval mode.

    The config stored under 'xp.cfg' is adapted to the requested device
    (float32 on 'cpu', float16 otherwise), a few conditioner entries are dropped
    when present, and the 'best_state' weights are loaded. The config used is
    attached to the returned model as `model.cfg`.
    """
    package = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)

    config = OmegaConf.create(package['xp.cfg'])
    config.device = str(device)
    config.dtype = 'float32' if config.device == 'cpu' else 'float16'

    removed_params = (
        'conditioners.self_wav.chroma_stem.cache_path',
        'conditioners.args.merge_text_conditions_p',
        'conditioners.args.drop_desc_p',
    )
    for param_name in removed_params:
        _delete_param(config, param_name)

    lm = builders.get_lm_model(config)
    lm.load_state_dict(package['best_state'])
    lm.eval()
    lm.cfg = config
    return lm
|
119 |
+
|
120 |
+
|
121 |
+
def load_lm_model_magnet(file_or_url_or_id: tp.Union[Path, str], compression_model_frame_rate: int,
                         device='cpu', cache_dir: tp.Optional[str] = None):
    """Build a MAGNeT language model from a checkpoint and return it in eval mode.

    On top of the device/dtype adaptation (float32 on 'cpu', float16 otherwise), the
    transformer config receives the compression model frame rate, the dataset
    segment duration and the masking span length before the model is built.
    The config used is attached to the returned model as `model.cfg`.
    """
    package = load_lm_model_ckpt(file_or_url_or_id, cache_dir=cache_dir)

    config = OmegaConf.create(package['xp.cfg'])
    config.device = str(device)
    config.dtype = 'float32' if config.device == 'cpu' else 'float16'

    for param_name in ('conditioners.args.merge_text_conditions_p', 'conditioners.args.drop_desc_p'):
        _delete_param(config, param_name)

    lm_config = config.transformer_lm
    lm_config.compression_model_framerate = compression_model_frame_rate
    lm_config.segment_duration = config.dataset.segment_duration
    lm_config.span_len = config.masking.span_len

    # MAGNeT models v1 support only xformers backend.
    from audiocraft.modules.transformer import set_efficient_attention_backend
    if config.transformer_lm.memory_efficient:
        set_efficient_attention_backend("xformers")

    lm = builders.get_lm_model(config)
    lm.load_state_dict(package['best_state'])
    lm.eval()
    lm.cfg = config
    return lm
|
147 |
+
|
148 |
+
|
149 |
+
def load_mbd_ckpt(file_or_url_or_id: tp.Union[Path, str],
                  filename: tp.Optional[str] = None,
                  cache_dir: tp.Optional[str] = None):
    """Load the raw multi-band diffusion checkpoint package stored as `filename`."""
    package = _get_state_dict(file_or_url_or_id, filename=filename, cache_dir=cache_dir)
    return package
|
153 |
+
|
154 |
+
|
155 |
+
def load_diffusion_models(file_or_url_or_id: tp.Union[Path, str],
                          device='cpu',
                          filename: tp.Optional[str] = None,
                          cache_dir: tp.Optional[str] = None):
    """Load one diffusion model and one sample processor per band from a checkpoint.

    The package is expected to hold 'sample_rate', 'n_bands' and, under each integer
    key in range(n_bands), a dict with 'cfg', 'model_state' and 'processor_state'.

    Returns:
        tuple: (models, processors, cfgs), three lists with one entry per band,
            models and processors being moved to `device`.
    """
    package = load_mbd_ckpt(file_or_url_or_id, filename=filename, cache_dir=cache_dir)
    models, processors, cfgs = [], [], []
    sample_rate = package['sample_rate']

    for band in range(package['n_bands']):
        band_package = package[band]
        band_cfg = band_package['cfg']

        diffusion_model = builders.get_diffusion_model(band_cfg)
        diffusion_model.load_state_dict(band_package['model_state'])
        diffusion_model.to(device)

        band_processor = builders.get_processor(cfg=band_cfg.processor, sample_rate=sample_rate)
        band_processor.load_state_dict(band_package['processor_state'])
        band_processor.to(device)

        models.append(diffusion_model)
        processors.append(band_processor)
        cfgs.append(band_cfg)

    return models, processors, cfgs
|
audiocraft/models/magnet.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
Main model for using MAGNeT. This will combine all the required components
|
9 |
+
and provide easy access to the generation API.
|
10 |
+
"""
|
11 |
+
import typing as tp
|
12 |
+
import torch
|
13 |
+
|
14 |
+
from .genmodel import BaseGenModel
|
15 |
+
from .loaders import load_compression_model, load_lm_model_magnet
|
16 |
+
|
17 |
+
|
18 |
+
class MAGNeT(BaseGenModel):
    """MAGNeT main model with convenient generation API.

    Args:
        See MusicGen class.
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # MAGNeT operates over a fixed sequence length defined in its config.
        self.duration = self.lm.cfg.dataset.segment_duration
        self.set_generation_params()

    @staticmethod
    def get_pretrained(name: str = 'facebook/magnet-small-10secs', device=None):
        """Return pretrained model, we provide six models:
        - facebook/magnet-small-10secs (300M), text to music, 10-second audio samples.
          # see: https://huggingface.co/facebook/magnet-small-10secs
        - facebook/magnet-medium-10secs (1.5B), text to music, 10-second audio samples.
          # see: https://huggingface.co/facebook/magnet-medium-10secs
        - facebook/magnet-small-30secs (300M), text to music, 30-second audio samples.
          # see: https://huggingface.co/facebook/magnet-small-30secs
        - facebook/magnet-medium-30secs (1.5B), text to music, 30-second audio samples.
          # see: https://huggingface.co/facebook/magnet-medium-30secs
        - facebook/audio-magnet-small (300M), text to sound-effect (10-second samples).
          # see: https://huggingface.co/facebook/audio-magnet-small
        - facebook/audio-magnet-medium (1.5B), text to sound-effect (10-second samples).
          # see: https://huggingface.co/facebook/audio-magnet-medium

        When `device` is None, 'cuda' is used if at least one CUDA device is
        reported, otherwise 'cpu'.
        """
        if device is None:
            device = 'cuda' if torch.cuda.device_count() else 'cpu'

        compression_model = load_compression_model(name, device=device)
        frame_rate = int(compression_model.frame_rate)
        lm = load_lm_model_magnet(name, compression_model_frame_rate=frame_rate, device=device)

        conditioners = lm.condition_provider.conditioners
        if 'self_wav' in conditioners:
            conditioners['self_wav'].match_len_on_eval = True

        return MAGNeT(name=name, compression_model=compression_model, lm=lm)

    def set_generation_params(self, use_sampling: bool = True, top_k: int = 0,
                              top_p: float = 0.9, temperature: float = 3.0,
                              max_cfg_coef: float = 10.0, min_cfg_coef: float = 1.0,
                              decoding_steps: tp.List[int] = [20, 10, 10, 10],
                              span_arrangement: str = 'nonoverlap'):
        """Set the generation parameters for MAGNeT.

        Args:
            use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True.
            top_k (int, optional): top_k used for sampling. Defaults to 0.
            top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.9.
            temperature (float, optional): Initial softmax temperature parameter. Defaults to 3.0.
            max_cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 10.0.
            min_cfg_coef (float, optional): End coefficient of classifier free guidance annealing. Defaults to 1.0.
            decoding_steps (list of n_q ints, optional): The number of iterative decoding steps,
                                                        for each of the n_q RVQ codebooks.
            span_arrangement (str, optional): Use either non-overlapping spans ('nonoverlap')
                                              or overlapping spans ('stride1') in the masking scheme.
        """
        # decoding_steps is copied (and coerced to int), so the stored list is
        # independent from the argument.
        self.generation_params = dict(
            use_sampling=use_sampling,
            temp=temperature,
            top_k=top_k,
            top_p=top_p,
            max_cfg_coef=max_cfg_coef,
            min_cfg_coef=min_cfg_coef,
            decoding_steps=list(map(int, decoding_steps)),
            span_arrangement=span_arrangement,
        )
|
audiocraft/models/multibanddiffusion.py
ADDED
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
Multi Band Diffusion models as described in
|
9 |
+
"From Discrete Tokens to High-Fidelity Audio Using Multi-Band Diffusion"
|
10 |
+
(paper link).
|
11 |
+
"""
|
12 |
+
|
13 |
+
import typing as tp
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import julius
|
17 |
+
|
18 |
+
from .unet import DiffusionUnet
|
19 |
+
from ..modules.diffusion_schedule import NoiseSchedule
|
20 |
+
from .encodec import CompressionModel
|
21 |
+
from ..solvers.compression import CompressionSolver
|
22 |
+
from .loaders import load_compression_model, load_diffusion_models
|
23 |
+
|
24 |
+
|
25 |
+
class DiffusionProcess:
    """Sampling for a diffusion Model.

    Pairs a diffusion U-Net with the noise schedule used to sample from it.

    Args:
        model (DiffusionUnet): Diffusion U-Net model.
        noise_schedule (NoiseSchedule): Noise schedule for diffusion process.
    """
    def __init__(self, model: DiffusionUnet, noise_schedule: NoiseSchedule) -> None:
        self.model = model
        self.schedule = noise_schedule

    def generate(self, condition: torch.Tensor, initial_noise: torch.Tensor,
                 step_list: tp.Optional[tp.List[int]] = None):
        """Perform one diffusion process to generate one of the bands.

        Args:
            condition (torch.Tensor): The embeddings from the compression model.
            initial_noise (torch.Tensor): The initial noise to start the process.
            step_list (list of int, optional): Forwarded unchanged to the schedule's
                `generate_subsampled`.
        Returns:
            The result of `self.schedule.generate_subsampled`.
        """
        subsampled_sampler = self.schedule.generate_subsampled
        return subsampled_sampler(
            model=self.model,
            initial=initial_noise,
            step_list=step_list,
            condition=condition,
        )
|
46 |
+
|
47 |
+
|
48 |
+
class MultiBandDiffusion:
    """Sample from multiple diffusion models.

    Every diffusion process generates one band; `generate` sums the bands
    to produce the final waveform.

    Args:
        DPs (list of DiffusionProcess): Diffusion processes.
        codec_model (CompressionModel): Underlying compression model used to obtain discrete tokens.
    """
    def __init__(self, DPs: tp.List[DiffusionProcess], codec_model: CompressionModel) -> None:
        self.DPs = DPs
        self.codec_model = codec_model
        # All tensors created by this class live on the device of the codec parameters.
        self.device = next(self.codec_model.parameters()).device

    @property
    def sample_rate(self) -> int:
        """Sample rate of the underlying compression model."""
        return self.codec_model.sample_rate

    @staticmethod
    def _build_processes(models, processors, cfgs, device) -> tp.List[DiffusionProcess]:
        """Wrap each loaded model with its own noise schedule (one process per model)."""
        return [
            DiffusionProcess(
                model=model,
                noise_schedule=NoiseSchedule(**cfg.schedule, sample_processor=processor, device=device))
            for model, processor, cfg in zip(models, processors, cfgs)
        ]

    @staticmethod
    def get_mbd_musicgen(device=None):
        """Load our diffusion models trained for MusicGen.

        Args:
            device (torch.device or str, optional): Device on which the models are loaded.
                Defaults to 'cuda' when available, else 'cpu'.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        path = 'facebook/multiband-diffusion'
        filename = 'mbd_musicgen_32khz.th'
        name = 'facebook/musicgen-small'
        codec_model = load_compression_model(name, device=device)
        models, processors, cfgs = load_diffusion_models(path, filename=filename, device=device)
        DPs = MultiBandDiffusion._build_processes(models, processors, cfgs, device)
        return MultiBandDiffusion(DPs=DPs, codec_model=codec_model)

    @staticmethod
    def get_mbd_24khz(bw: float = 3.0,
                      device: tp.Optional[tp.Union[torch.device, str]] = None,
                      n_q: tp.Optional[int] = None):
        """Get the pretrained Models for MultibandDiffusion.

        Args:
            bw (float): Bandwidth of the compression model.
            device (torch.device or str, optional): Device on which the models are loaded.
            n_q (int, optional): Number of quantizers to use within the compression model.
                When given it must agree with `bw`.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Supported bandwidths and the number of codebooks each one implies.
        codebooks_for_bw = {1.5: 2, 3.0: 4, 6.0: 8}
        assert bw in codebooks_for_bw, f"bandwidth {bw} not available"
        if n_q is not None:
            assert n_q in [2, 4, 8]
            assert codebooks_for_bw[bw] == n_q, \
                f"bandwidth and number of codebooks missmatch to use n_q = {n_q} bw should be {n_q * (1.5 / 2)}"
        n_q = codebooks_for_bw[bw]
        codec_model = CompressionSolver.model_from_checkpoint(
            '//pretrained/facebook/encodec_24khz', device=device)
        codec_model.set_num_codebooks(n_q)
        codec_model = codec_model.to(device)
        path = 'facebook/multiband-diffusion'
        filename = f'mbd_comp_{n_q}.pt'
        models, processors, cfgs = load_diffusion_models(path, filename=filename, device=device)
        DPs = MultiBandDiffusion._build_processes(models, processors, cfgs, device)
        return MultiBandDiffusion(DPs=DPs, codec_model=codec_model)

    @torch.no_grad()
    def get_condition(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
        """Get the conditioning (i.e. latent representations of the compression model) from a waveform.

        Args:
            wav (torch.Tensor): The audio that we want to extract the conditioning from.
            sample_rate (int): Sample rate of the audio; it is resampled to the codec
                sample rate when the two differ.
        """
        if sample_rate != self.sample_rate:
            wav = julius.resample_frac(wav, sample_rate, self.sample_rate)
        codes, scale = self.codec_model.encode(wav)
        assert scale is None, "Scaled compression models not supported."
        emb = self.get_emb(codes)
        return emb

    @torch.no_grad()
    def get_emb(self, codes: torch.Tensor):
        """Get latent representation from the discrete codes.

        Args:
            codes (torch.Tensor): Discrete tokens.
        """
        emb = self.codec_model.decode_latent(codes)
        return emb

    def generate(self, emb: torch.Tensor, size: tp.Optional[torch.Size] = None,
                 step_list: tp.Optional[tp.List[int]] = None):
        """Generate waveform audio from the latent embeddings of the compression model.

        Args:
            emb (torch.Tensor): Conditioning embeddings
            size (None, torch.Size): Size of the output
                if None this is computed from the typical upsampling of the model.
            step_list (list[int], optional): list of Markov chain steps, defaults to 50 linearly spaced step.
        Returns:
            torch.Tensor: Sum of the outputs of all diffusion processes, of shape `size`.
        """
        if size is None:
            upsampling = int(self.codec_model.sample_rate / self.codec_model.frame_rate)
            size = torch.Size([emb.size(0), self.codec_model.channels, emb.size(-1) * upsampling])
        assert size[0] == emb.size(0)
        # Allocate directly on the target device instead of creating on CPU and copying.
        out = torch.zeros(size, device=self.device)
        for DP in self.DPs:
            out += DP.generate(condition=emb, step_list=step_list, initial_noise=torch.randn_like(out))
        return out

    def re_eq(self, wav: torch.Tensor, ref: torch.Tensor, n_bands: int = 32, strictness: float = 1):
        """Match the eq to the encodec output by matching the standard deviation of some frequency bands.

        Args:
            wav (torch.Tensor): Audio to equalize.
            ref (torch.Tensor): Reference audio from which we match the spectrogram.
            n_bands (int): Number of bands of the eq.
            strictness (float): How strict the matching. 0 is no matching, 1 is exact matching.
        """
        split = julius.SplitBands(n_bands=n_bands, sample_rate=self.codec_model.sample_rate).to(wav.device)
        bands = split(wav)
        bands_ref = split(ref)
        out = torch.zeros_like(ref)
        # Lower bound on the band std: a silent band of `wav` would otherwise give a
        # 0/0 or x/0 gain, and the resulting NaN/inf would contaminate the whole sum.
        min_std = torch.finfo(wav.dtype).eps
        for i in range(n_bands):
            gain = (bands_ref[i].std() / bands[i].std().clamp(min=min_std)) ** strictness
            out += bands[i] * gain
        return out

    def regenerate(self, wav: torch.Tensor, sample_rate: int):
        """Regenerate a waveform through compression and diffusion regeneration.

        Args:
            wav (torch.Tensor): Original 'ground truth' audio.
            sample_rate (int): Sample rate of the input (and output) wav.
        """
        if sample_rate != self.codec_model.sample_rate:
            wav = julius.resample_frac(wav, sample_rate, self.codec_model.sample_rate)
        emb = self.get_condition(wav, sample_rate=self.codec_model.sample_rate)
        size = wav.size()
        out = self.generate(emb, size=size)
        if sample_rate != self.codec_model.sample_rate:
            # Bring the result back to the caller's sample rate.
            out = julius.resample_frac(out, self.codec_model.sample_rate, sample_rate)
        return out

    def tokens_to_wav(self, tokens: torch.Tensor, n_bands: int = 32):
        """Generate Waveform audio with diffusion from the discrete codes.

        The diffusion output is equalized against the plain codec decoding of the same tokens.

        Args:
            tokens (torch.Tensor): Discrete codes.
            n_bands (int): Bands for the eq matching.
        """
        wav_encodec = self.codec_model.decode(tokens)
        condition = self.get_emb(tokens)
        wav_diffusion = self.generate(emb=condition, size=wav_encodec.size())
        return self.re_eq(wav=wav_diffusion, ref=wav_encodec, n_bands=n_bands)
|
audiocraft/models/musicgen.py
ADDED
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
Main model for using MusicGen. This will combine all the required components
|
9 |
+
and provide easy access to the generation API.
|
10 |
+
"""
|
11 |
+
|
12 |
+
import typing as tp
|
13 |
+
import warnings
|
14 |
+
|
15 |
+
import torch
|
16 |
+
|
17 |
+
from .encodec import CompressionModel
|
18 |
+
from .genmodel import BaseGenModel
|
19 |
+
from .lm import LMModel
|
20 |
+
from .builders import get_debug_compression_model, get_debug_lm_model
|
21 |
+
from .loaders import load_compression_model, load_lm_model
|
22 |
+
from ..data.audio_utils import convert_audio
|
23 |
+
from ..modules.conditioners import ConditioningAttributes, WavCondition
|
24 |
+
|
25 |
+
|
26 |
+
MelodyList = tp.List[tp.Optional[torch.Tensor]]
|
27 |
+
MelodyType = tp.Union[torch.Tensor, MelodyList]
|
28 |
+
|
29 |
+
|
30 |
+
# backward compatible names mapping
|
31 |
+
_HF_MODEL_CHECKPOINTS_MAP = {
|
32 |
+
"small": "facebook/musicgen-small",
|
33 |
+
"medium": "facebook/musicgen-medium",
|
34 |
+
"large": "facebook/musicgen-large",
|
35 |
+
"melody": "facebook/musicgen-melody",
|
36 |
+
}
|
37 |
+
|
38 |
+
|
39 |
+
class MusicGen(BaseGenModel):
    """MusicGen main model with convenient generation API.

    Args:
        name (str): name of the model.
        compression_model (CompressionModel): Compression model
            used to map audio to invertible discrete representations.
        lm (LMModel): Language model over discrete representations.
        max_duration (float, optional): maximum duration the model can produce,
            otherwise, inferred from the training params.
    """
    def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
                 max_duration: tp.Optional[float] = None):
        super().__init__(name, compression_model, lm, max_duration)
        self.set_generation_params(duration=15)  # default duration

    @staticmethod
    def get_pretrained(name: str = 'facebook/musicgen-melody', device=None):
        """Return pretrained model, we provide four models:
        - facebook/musicgen-small (300M), text to music,
          # see: https://huggingface.co/facebook/musicgen-small
        - facebook/musicgen-medium (1.5B), text to music,
          # see: https://huggingface.co/facebook/musicgen-medium
        - facebook/musicgen-melody (1.5B) text to music and text+melody to music,
          # see: https://huggingface.co/facebook/musicgen-melody
        - facebook/musicgen-large (3.3B), text to music,
          # see: https://huggingface.co/facebook/musicgen-large

        Args:
            name (str): Pre-trained id, a deprecated short name ('small', 'medium',
                'large', 'melody'), or 'debug' (unit tests only).
            device: Device to load on; defaults to 'cuda' when a CUDA device is present, else 'cpu'.
        """
        if device is None:
            if torch.cuda.device_count():
                device = 'cuda'
            else:
                device = 'cpu'

        if name == 'debug':
            # used only for unit tests
            compression_model = get_debug_compression_model(device)
            lm = get_debug_lm_model(device)
            return MusicGen(name, compression_model, lm, max_duration=30)

        if name in _HF_MODEL_CHECKPOINTS_MAP:
            warnings.warn(
                "MusicGen pretrained model relying on deprecated checkpoint mapping. " +
                f"Please use full pre-trained id instead: facebook/musicgen-{name}")
            name = _HF_MODEL_CHECKPOINTS_MAP[name]

        lm = load_lm_model(name, device=device)
        compression_model = load_compression_model(name, device=device)
        if 'self_wav' in lm.condition_provider.conditioners:
            # Only models that ship a 'self_wav' conditioner get these evaluation-time settings.
            lm.condition_provider.conditioners['self_wav'].match_len_on_eval = True
            lm.condition_provider.conditioners['self_wav']._use_masking = False

        return MusicGen(name, compression_model, lm)

    def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
                              top_p: float = 0.0, temperature: float = 1.0,
                              duration: float = 30.0, cfg_coef: float = 3.0,
                              two_step_cfg: bool = False, extend_stride: float = 18):
        """Set the generation parameters for MusicGen.

        Args:
            use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True.
            top_k (int, optional): top_k used for sampling. Defaults to 250.
            top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.
            temperature (float, optional): Softmax temperature parameter. Defaults to 1.0.
            duration (float, optional): Duration of the generated waveform. Defaults to 30.0.
            cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0.
            two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
                instead of batching together the two. This has some impact on how things
                are padded but seems to have little impact in practice.
            extend_stride: when doing extended generation (i.e. more than 30 seconds), by how much
                should we extend the audio each time. Larger values will mean less context is
                preserved, and shorter value will require extra computations.
        """
        assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
        self.extend_stride = extend_stride
        self.duration = duration
        self.generation_params = {
            'use_sampling': use_sampling,
            'temp': temperature,
            'top_k': top_k,
            'top_p': top_p,
            'cfg_coef': cfg_coef,
            'two_step_cfg': two_step_cfg,
        }

    def generate_with_chroma(self, descriptions: tp.List[str], melody_wavs: MelodyType,
                             melody_sample_rate: int, progress: bool = False,
                             return_tokens: bool = False) -> tp.Union[torch.Tensor,
                                                                      tp.Tuple[torch.Tensor, torch.Tensor]]:
        """Generate samples conditioned on text and melody.

        Args:
            descriptions (list of str): A list of strings used as text conditioning.
            melody_wavs: (torch.Tensor or list of Tensor): A batch of waveforms used as
                melody conditioning. Should have shape [B, C, T] with B matching the description length,
                C=1 or 2. It can be [C, T] if there is a single description. It can also be
                a list of [C, T] tensors.
            melody_sample_rate: (int): Sample rate of the melody waveforms.
            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
            return_tokens (bool, optional): If True, also return the generated tokens. Defaults to False.
        Returns:
            The generated audio, or a tuple `(audio, tokens)` when `return_tokens` is True.
        Raises:
            ValueError: If `melody_wavs` is a tensor that is not 3-dimensional after the
                optional [C, T] -> [1, C, T] expansion.
        """
        if isinstance(melody_wavs, torch.Tensor):
            if melody_wavs.dim() == 2:
                melody_wavs = melody_wavs[None]
            if melody_wavs.dim() != 3:
                raise ValueError("Melody wavs should have a shape [B, C, T].")
            melody_wavs = list(melody_wavs)
        else:
            for melody in melody_wavs:
                if melody is not None:
                    assert melody.dim() == 2, "One melody in the list has the wrong number of dims."

        # Bring every provided melody to the model sample rate / channel count; None entries stay None.
        melody_wavs = [
            convert_audio(wav, melody_sample_rate, self.sample_rate, self.audio_channels)
            if wav is not None else None
            for wav in melody_wavs]
        attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions=descriptions, prompt=None,
                                                                        melody_wavs=melody_wavs)
        assert prompt_tokens is None
        tokens = self._generate_tokens(attributes, prompt_tokens, progress)
        if return_tokens:
            return self.generate_audio(tokens), tokens
        return self.generate_audio(tokens)

    def _null_wav_condition(self) -> WavCondition:
        """Build the placeholder melody condition: a single zero sample with length 0."""
        return WavCondition(
            torch.zeros((1, 1, 1), device=self.device),
            torch.tensor([0], device=self.device),
            sample_rate=[self.sample_rate],
            path=[None])

    @torch.no_grad()
    def _prepare_tokens_and_attributes(
            self,
            descriptions: tp.Sequence[tp.Optional[str]],
            prompt: tp.Optional[torch.Tensor],
            melody_wavs: tp.Optional[MelodyList] = None,
    ) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]:
        """Prepare model inputs.

        Args:
            descriptions (list of str): A list of strings used as text conditioning.
            prompt (torch.Tensor): A batch of waveforms used for continuation.
            melody_wavs (list of torch.Tensor or None, optional): Waveforms
                used as melody conditioning, one entry per description. Defaults to None.
        Returns:
            Tuple of the conditioning attributes (one per description) and the prompt
            tokens encoded by the compression model (None when no prompt is given).
        Raises:
            RuntimeError: If melodies are given but the LM has no 'self_wav' conditioner.
        """
        attributes = [
            ConditioningAttributes(text={'description': description})
            for description in descriptions]

        if melody_wavs is None:
            for attr in attributes:
                attr.wav['self_wav'] = self._null_wav_condition()
        else:
            if 'self_wav' not in self.lm.condition_provider.conditioners:
                raise RuntimeError("This model doesn't support melody conditioning. "
                                   "Use the `melody` model.")
            assert len(melody_wavs) == len(descriptions), \
                f"number of melody wavs must match number of descriptions! " \
                f"got melody len={len(melody_wavs)}, and descriptions len={len(descriptions)}"
            for attr, melody in zip(attributes, melody_wavs):
                if melody is None:
                    attr.wav['self_wav'] = self._null_wav_condition()
                else:
                    attr.wav['self_wav'] = WavCondition(
                        melody[None].to(device=self.device),
                        torch.tensor([melody.shape[-1]], device=self.device),
                        sample_rate=[self.sample_rate],
                        path=[None],
                    )

        if prompt is not None:
            # `descriptions` was already iterated above, so it cannot be None here.
            assert len(descriptions) == len(prompt), "Prompt and nb. descriptions doesn't match"
            prompt = prompt.to(self.device)
            prompt_tokens, scale = self.compression_model.encode(prompt)
            assert scale is None
        else:
            prompt_tokens = None
        return attributes, prompt_tokens

    def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
                         prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor:
        """Generate discrete audio tokens given audio prompt and/or conditions.

        Args:
            attributes (list of ConditioningAttributes): Conditions used for generation (text/melody).
            prompt_tokens (torch.Tensor, optional): Audio prompt used for continuation.
            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
        Returns:
            torch.Tensor: Generated tokens as returned by `self.lm.generate`; for extended
                generation the chunks are concatenated along the last dimension.
        """
        total_gen_len = int(self.duration * self.frame_rate)
        max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate)
        current_gen_offset: int = 0

        def _progress_callback(generated_tokens: int, tokens_to_generate: int):
            generated_tokens += current_gen_offset
            # NOTE(review): during extended generation `generated_tokens` is shifted by the
            # global offset while `tokens_to_generate` is whatever the LM reports (presumably
            # per chunk) - confirm both are on the same scale.
            if self._progress_callback is not None:
                # Note that total_gen_len might be quite wrong depending on the
                # codebook pattern used, but with delay it is almost accurate.
                self._progress_callback(generated_tokens, tokens_to_generate)
            else:
                print(f'{generated_tokens: 6d} / {tokens_to_generate: 6d}', end='\r')

        if prompt_tokens is not None:
            assert max_prompt_len >= prompt_tokens.shape[-1], \
                "Prompt is longer than audio to generate"

        callback = None
        if progress:
            callback = _progress_callback

        if self.duration <= self.max_duration:
            # generate by sampling from LM, simple case.
            with self.autocast:
                gen_tokens = self.lm.generate(
                    prompt_tokens, attributes,
                    callback=callback, max_gen_len=total_gen_len, **self.generation_params)

        else:
            # now this gets a bit messier, we need to handle prompts,
            # melody conditioning etc.
            ref_wavs = [attr.wav['self_wav'] for attr in attributes]
            all_tokens = []
            if prompt_tokens is None:
                prompt_length = 0
            else:
                all_tokens.append(prompt_tokens)
                prompt_length = prompt_tokens.shape[-1]

            assert self.extend_stride is not None, "Stride should be defined to generate beyond max_duration"
            assert self.extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
            stride_tokens = int(self.frame_rate * self.extend_stride)

            while current_gen_offset + prompt_length < total_gen_len:
                time_offset = current_gen_offset / self.frame_rate
                chunk_duration = min(self.duration - time_offset, self.max_duration)
                max_gen_len = int(chunk_duration * self.frame_rate)
                for attr, ref_wav in zip(attributes, ref_wavs):
                    wav_length = ref_wav.length.item()
                    if wav_length == 0:
                        # Placeholder (no melody) conditions are left untouched.
                        continue
                    # We will extend the wav periodically if it not long enough.
                    # we have to do it here rather than in conditioners.py as otherwise
                    # we wouldn't have the full wav.
                    initial_position = int(time_offset * self.sample_rate)
                    wav_target_length = int(self.max_duration * self.sample_rate)
                    positions = torch.arange(initial_position,
                                             initial_position + wav_target_length, device=self.device)
                    attr.wav['self_wav'] = WavCondition(
                        ref_wav[0][..., positions % wav_length],
                        torch.full_like(ref_wav[1], wav_target_length),
                        [self.sample_rate] * ref_wav[0].size(0),
                        [None], [0.])
                with self.autocast:
                    gen_tokens = self.lm.generate(
                        prompt_tokens, attributes,
                        callback=callback, max_gen_len=max_gen_len, **self.generation_params)
                if prompt_tokens is None:
                    all_tokens.append(gen_tokens)
                else:
                    # Keep only the newly generated part; the prompt is already in `all_tokens`.
                    all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:])
                # The tail past the stride becomes the prompt (context) of the next chunk.
                prompt_tokens = gen_tokens[:, :, stride_tokens:]
                prompt_length = prompt_tokens.shape[-1]
                current_gen_offset += stride_tokens

            gen_tokens = torch.cat(all_tokens, dim=-1)
        return gen_tokens
|
audiocraft/models/unet.py
ADDED
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
"""
|
8 |
+
Pytorch Unet Module used for diffusion.
|
9 |
+
"""
|
10 |
+
|
11 |
+
from dataclasses import dataclass
|
12 |
+
import typing as tp
|
13 |
+
|
14 |
+
import torch
|
15 |
+
from torch import nn
|
16 |
+
from torch.nn import functional as F
|
17 |
+
from audiocraft.modules.transformer import StreamingTransformer, create_sin_embedding
|
18 |
+
|
19 |
+
|
20 |
+
@dataclass
class Output:
    """Container holding a single tensor under the `sample` field."""
    # The wrapped tensor.
    sample: torch.Tensor
|
23 |
+
|
24 |
+
|
25 |
+
def get_model(cfg, channels: int, side: int, num_steps: int):
    """Instantiate the diffusion model selected by `cfg.model`.

    Args:
        cfg: Configuration object; reads `cfg.model` and, for 'unet', `cfg.diffusion_unet`
            (expanded as keyword arguments of `DiffusionUnet`).
        channels (int): Passed to the model as `chin`.
        side (int): Not used by this function.
        num_steps (int): Passed to the model as `num_steps`.
    Raises:
        RuntimeError: If `cfg.model` is anything other than 'unet'.
    """
    if cfg.model != 'unet':
        raise RuntimeError('Not Implemented')
    unet_kwargs = dict(cfg.diffusion_unet)
    return DiffusionUnet(chin=channels, num_steps=num_steps, **unet_kwargs)
|
31 |
+
|
32 |
+
|
33 |
+
class ResBlock(nn.Module):
    """Residual block made of two (GroupNorm -> activation -> Conv1d -> Dropout1d) stages.

    Args:
        channels (int): Number of input and output channels.
        kernel (int): Convolution kernel size.
        norm_groups (int): Number of groups of the group norms.
        dilation (int): Dilation of both convolutions.
        activation (type of nn.Module): Activation class, instantiated once per stage.
        dropout (float): Probability given to the Dropout1d layers.
    """
    def __init__(self, channels: int, kernel: int = 3, norm_groups: int = 4,
                 dilation: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
                 dropout: float = 0.):
        super().__init__()
        # Both convolutions use stride 1; this padding keeps the length unchanged
        # whenever dilation * (kernel - 1) is even.
        conv_kwargs = dict(
            kernel_size=kernel,
            stride=1,
            padding=dilation * (kernel - 1) // 2,
            dilation=dilation,
        )

        # first stage
        self.norm1 = nn.GroupNorm(norm_groups, channels)
        self.conv1 = nn.Conv1d(channels, channels, **conv_kwargs)
        self.activation1 = activation()
        self.dropout1 = nn.Dropout1d(dropout)

        # second stage
        self.norm2 = nn.GroupNorm(norm_groups, channels)
        self.conv2 = nn.Conv1d(channels, channels, **conv_kwargs)
        self.activation2 = activation()
        self.dropout2 = nn.Dropout1d(dropout)

    def forward(self, x):
        """Apply both stages to `x` and add the input back (skip connection)."""
        stages = (
            (self.norm1, self.activation1, self.conv1, self.dropout1),
            (self.norm2, self.activation2, self.conv2, self.dropout2),
        )
        residual = x
        for norm, act, conv, drop in stages:
            residual = drop(conv(act(norm(residual))))
        return x + residual
|
56 |
+
|
57 |
+
|
58 |
+
class DecoderLayer(nn.Module):
    """Decoder stage: residual blocks, GroupNorm, activation, then a transposed convolution.

    Args:
        chin (int): Number of input channels.
        chout (int): Number of output channels.
        kernel (int): Kernel size of the transposed convolution.
        stride (int): Stride of the transposed convolution.
        norm_groups (int): Number of groups of the group norms.
        res_blocks (int): Number of residual blocks; block `idx` uses dilation 2**idx.
        activation (type of nn.Module): Activation class applied before the transposed convolution.
        dropout (float): Dropout probability forwarded to the residual blocks.
    """
    def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2,
                 norm_groups: int = 4, res_blocks: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
                 dropout: float = 0.):
        super().__init__()
        blocks = []
        for idx in range(res_blocks):
            blocks.append(ResBlock(chin, norm_groups=norm_groups, dilation=2**idx, dropout=dropout))
        self.res_blocks = nn.Sequential(*blocks)
        self.norm = nn.GroupNorm(norm_groups, chin)
        self.convtr = nn.ConvTranspose1d(chin, chout, kernel, stride, (kernel - stride) // 2, bias=False)
        self.activation = activation()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run `x` through the residual blocks, norm, activation and transposed convolution."""
        hidden = self.res_blocks(x)
        hidden = self.activation(self.norm(hidden))
        return self.convtr(hidden)
|
78 |
+
|
79 |
+
|
80 |
+
class EncoderLayer(nn.Module):
    """Encoder stage: strided convolution, GroupNorm, activation, then residual blocks.

    Args:
        chin (int): Number of input channels.
        chout (int): Number of output channels.
        kernel (int): Kernel size of the convolution.
        stride (int): Stride of the convolution.
        norm_groups (int): Number of groups of the group norms.
        res_blocks (int): Number of residual blocks; block `idx` uses dilation 2**idx.
        activation (type of nn.Module): Activation class applied after the norm.
        dropout (float): Dropout probability forwarded to the residual blocks.
    """
    def __init__(self, chin: int, chout: int, kernel: int = 4, stride: int = 2,
                 norm_groups: int = 4, res_blocks: int = 1, activation: tp.Type[nn.Module] = nn.ReLU,
                 dropout: float = 0.):
        super().__init__()
        self.conv = nn.Conv1d(chin, chout, kernel, stride, (kernel - stride) // 2, bias=False)
        self.norm = nn.GroupNorm(norm_groups, chout)
        self.activation = activation()
        blocks = []
        for idx in range(res_blocks):
            blocks.append(ResBlock(chout, norm_groups=norm_groups, dilation=2**idx, dropout=dropout))
        self.res_blocks = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Right-pad `x` to a multiple of the stride, then convolve, normalize, activate and refine."""
        _, _, length = x.shape
        stride = self.conv.stride[0]
        # Number of zeros needed on the right so that the length is a multiple of the stride.
        missing = -length % stride
        padded = F.pad(x, (0, missing))
        hidden = self.activation(self.norm(self.conv(padded)))
        return self.res_blocks(hidden)
|
105 |
+
|
106 |
+
|
107 |
+
class BLSTM(nn.Module):
    """BiLSTM with same hidden units as input dim.

    The two directions are merged back to `dim` features by a linear layer.

    Args:
        dim (int): Input, hidden and output feature size.
        layers (int): Number of stacked LSTM layers.
    """
    def __init__(self, dim, layers=2):
        super().__init__()
        self.lstm = nn.LSTM(input_size=dim, hidden_size=dim, num_layers=layers, bidirectional=True)
        self.linear = nn.Linear(2 * dim, dim)

    def forward(self, x):
        """Move the last axis of `x` to the front, apply LSTM + linear, and restore the axis order."""
        sequence_first = x.permute(2, 0, 1)
        hidden, _ = self.lstm(sequence_first)
        projected = self.linear(hidden)
        return projected.permute(1, 2, 0)
|
121 |
+
|
122 |
+
|
123 |
+
class DiffusionUnet(nn.Module):
    """U-Net style network conditioned on a diffusion step index.

    The network stacks `depth` encoder layers and the mirrored decoder layers, linked by
    additive skip connections. A learned embedding of the step index is added after the first
    encoder layer (and after every encoder layer when `emb_all_layers` is set). The bottleneck
    is processed by a transformer, by a BLSTM, or replaced with zeros when neither is enabled.

    Args:
        chin (int): Number of input channels.
        hidden (int): Number of channels produced by the first encoder layer.
        depth (int): Number of encoder / decoder layer pairs.
        growth (float): Factor applied to the channel count after each layer.
        max_channels (int): Upper bound on the channel count of any layer.
        num_steps (int): Number of entries in each step embedding table.
        emb_all_layers (bool): If True, every encoder layer after the first one gets its own
            step embedding; otherwise only the first layer receives one.
        cross_attention (bool): Only honoured when `transformer` is True. The condition is then
            given to the transformer as cross-attention source instead of being summed
            into the bottleneck.
        bilstm (bool): If True, build a BLSTM for the bottleneck (used only when `transformer` is False).
        transformer (bool): If True, process the bottleneck with a StreamingTransformer.
        codec_dim (int, optional): Channel dimension of the conditioning signal. When provided,
            `forward` requires a `condition` tensor.
        **kwargs: Extra arguments forwarded to every EncoderLayer and DecoderLayer.
    """

    def __init__(self, chin: int = 3, hidden: int = 24, depth: int = 3, growth: float = 2.,
                 max_channels: int = 10_000, num_steps: int = 1000, emb_all_layers=False, cross_attention: bool = False,
                 bilstm: bool = False, transformer: bool = False,
                 codec_dim: tp.Optional[int] = None, **kwargs):
        super().__init__()
        # NOTE: the order in which submodules are created is kept exactly as is, so that
        # parameter registration order and random initialisation stay unchanged.
        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        self.embeddings: tp.Optional[nn.ModuleList] = None
        # Step embedding added to the output of the first encoder layer.
        self.embedding = nn.Embedding(num_steps, hidden)
        if emb_all_layers:
            self.embeddings = nn.ModuleList()
        # Never assigned anything else in this class; kept because it is a public attribute.
        self.condition_embedding: tp.Optional[nn.Module] = None
        for d in range(depth):
            encoder = EncoderLayer(chin, hidden, **kwargs)
            decoder = DecoderLayer(hidden, chin, **kwargs)
            self.encoders.append(encoder)
            # Decoders are stored in reverse so that iterating them mirrors the encoders.
            self.decoders.insert(0, decoder)
            if emb_all_layers and d > 0:
                assert self.embeddings is not None
                # `self.embeddings[d - 1]` is the step embedding of encoder layer `d`.
                self.embeddings.append(nn.Embedding(num_steps, hidden))
            chin = hidden
            hidden = min(int(chin * growth), max_channels)
        # From here on `chin` is the channel count at the bottleneck.
        self.bilstm: tp.Optional[nn.Module]
        if bilstm:
            self.bilstm = BLSTM(chin)
        else:
            self.bilstm = None
        self.use_transformer = transformer
        # Cross-attention is only meaningful together with the transformer.
        self.cross_attention = False
        if transformer:
            self.cross_attention = cross_attention
            self.transformer = StreamingTransformer(chin, 8, 6, bias_ff=False, bias_attn=False,
                                                    cross_attention=cross_attention)

        self.use_codec = False
        if codec_dim is not None:
            # Kernel size 1 convolution mapping the condition channels to the bottleneck channels.
            self.conv_codec = nn.Conv1d(codec_dim, chin, 1)
            self.use_codec = True

    def forward(self, x: torch.Tensor, step: tp.Union[int, torch.Tensor], condition: tp.Optional[torch.Tensor] = None):
        """Run the network on `x` for the given diffusion step.

        Args:
            x (torch.Tensor): Rank 3 input with the batch first; the last dimension is the one
                cropped back to the input length at the end (presumably [B, C, T], to be confirmed
                against EncoderLayer).
            step (int or torch.Tensor): Step index. An int is expanded to the batch size, a tensor
                is used as is (presumably one index per batch item).
            condition (torch.Tensor, optional): Conditioning signal, required when the model was
                built with `codec_dim`.

        Returns:
            `Output(z)` where `z` is the decoded tensor cropped to the size of `x` along the last dimension.

        Raises:
            AssertionError: if the model expects a condition and none is given, or if the
                condition is more than twice as long as the bottleneck.
        """
        bs = x.size(0)
        # `isinstance` rather than an exact type check, so tensor subclasses are accepted too.
        if isinstance(step, torch.Tensor):
            step_tensor = step
        else:
            step_tensor = torch.tensor([step], device=x.device, dtype=torch.long).expand(bs)

        skips = []
        z = x
        for idx, encoder in enumerate(self.encoders):
            z = encoder(z)
            # The step embedding is reshaped to [bs, C, 1] and broadcast over the last dimension.
            if idx == 0:
                z = z + self.embedding(step_tensor).view(bs, -1, 1).expand_as(z)
            elif self.embeddings is not None:
                z = z + self.embeddings[idx - 1](step_tensor).view(bs, -1, 1).expand_as(z)

            skips.append(z)

        # Defined on every path: the transformer branch below reads it even when the model has
        # no codec conditioning, which used to fail with an UnboundLocalError.
        cross_attention_src: tp.Optional[torch.Tensor] = None
        if self.use_codec:  # insert condition in the bottleneck
            assert condition is not None, "Model defined for conditionnal generation"
            condition_emb = self.conv_codec(condition)  # reshape to the bottleneck dim
            assert condition_emb.size(-1) <= 2 * z.size(-1), \
                f"You are downsampling the conditionning with factor >=2 : {condition_emb.size(-1)=} and {z.size(-1)=}"
            if not self.cross_attention:
                # Resample the condition to the bottleneck length and sum it in.
                condition_emb = torch.nn.functional.interpolate(condition_emb, z.size(-1))
                assert z.size() == condition_emb.size()
                # NOTE(review): in-place add kept on purpose to preserve existing behaviour; `z` is
                # the same tensor object as `skips[-1]`, so the last skip receives the condition too.
                z += condition_emb
            else:
                cross_attention_src = condition_emb.permute(0, 2, 1)  # B, T, C
                B, T, C = cross_attention_src.shape
                positions = torch.arange(T, device=x.device).view(1, -1, 1)
                pos_emb = create_sin_embedding(positions, C, max_period=10_000, dtype=cross_attention_src.dtype)
                cross_attention_src = cross_attention_src + pos_emb

        if self.use_transformer:
            # The transformer is fed with the last two dimensions swapped, then swapped back.
            z = self.transformer(z.permute(0, 2, 1), cross_attention_src=cross_attention_src).permute(0, 2, 1)
        elif self.bilstm is None:
            # No bottleneck module: the decoders only see the skip connections.
            z = torch.zeros_like(z)
        else:
            z = self.bilstm(z)

        for decoder in self.decoders:
            s = skips.pop(-1)
            # Crop to the skip length before summing.
            z = z[:, :, :s.shape[2]]
            z = z + s
            z = decoder(z)

        z = z[:, :, :x.shape[2]]
        return Output(z)
|