import typing as tp
from dataclasses import dataclass

import torch
from einops import rearrange
from torch import nn
from tqdm.auto import trange

from .codebook_patterns import (
    CodebooksPatternProvider,
    DelayedPatternProvider,
    MusicLMPattern,
    ParallelPatternProvider,
    UnrolledPatternProvider
)
from .conditioners import MultiConditioner, create_multi_conditioner_from_conditioning_config
from .factory import create_pretransform_from_config
from .lm_backbone import AudioLMBackbone, XTransformersAudioLMBackbone, ContinuousTransformerAudioLMBackbone
from .pretransforms import Pretransform, AutoencoderPretransform, PretrainedDACPretransform, AudiocraftCompressionPretransform
from .utils import multinomial, sample_top_k, sample_top_p


@dataclass
class LMOutput:
    # Output of AudioLanguageModel.compute_logits: logits reverted to the original code
    # layout, plus a mask marking the positions that hold valid predictions.
    logits: torch.Tensor  # [batch, num_quantizers, seq_len, codebook_size]
    mask: torch.Tensor    # [batch, num_quantizers, seq_len]
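
# Illustrative only: a minimal sketch (not part of the original module) of how the
# logits/mask pair in LMOutput is typically consumed by a training loss. The helper
# name _example_masked_cross_entropy is hypothetical; real training code may weight
# codebooks differently.
def _example_masked_cross_entropy(output: LMOutput, codes: torch.Tensor) -> torch.Tensor:
    # codes: [batch, num_quantizers, seq_len] integer targets
    logits = rearrange(output.logits, "b n t c -> (b n t) c")
    targets = rearrange(codes, "b n t -> (b n t)")
    mask = rearrange(output.mask, "b n t -> (b n t)").bool()

    # Invalid positions are filled with NaN by the pattern reversion in compute_logits,
    # so select only the valid positions before computing the loss.
    return torch.nn.functional.cross_entropy(logits[mask], targets[mask])
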

class AudioLanguageModel(nn.Module):
    def __init__(
        self,
        pattern_provider: CodebooksPatternProvider,
        backbone: AudioLMBackbone,
        num_quantizers: int,
        codebook_size: int
    ):
        super().__init__()

        self.pattern_provider = pattern_provider
        self.backbone = backbone
        self.num_quantizers = num_quantizers
        self.codebook_size = codebook_size

        # The mask token uses the index right after the last real code, so the embeddings
        # have codebook_size + 1 entries while the heads only predict real codes.
        self.masked_token_id = codebook_size

        # Per-quantizer token embeddings
        self.embeds = nn.ModuleList([nn.Embedding(codebook_size + 1, backbone.embed_dim) for _ in range(num_quantizers)])

        # Per-quantizer prediction heads
        self.quantizer_heads = nn.ModuleList([
            nn.Linear(backbone.embed_dim, codebook_size) for _ in range(num_quantizers)
        ])

    def forward(
        self,
        sequence: torch.Tensor,  # [batch, num_quantizers, seq_len]
        prepend_cond=None,
        prepend_cond_mask=None,
        cross_attn_cond=None,
        **kwargs
    ):
        batch, num_quantizers, seq_len = sequence.shape

        assert num_quantizers == self.num_quantizers, "Number of quantizers in sequence must match number of quantizers in model"

        # Sum the embeddings of all codebooks at each timestep
        backbone_input = sum([self.embeds[i](sequence[:, i]) for i in range(num_quantizers)])

        dtype = next(self.parameters()).dtype

        if cross_attn_cond is not None:
            cross_attn_cond = cross_attn_cond.to(dtype)

        if prepend_cond is not None:
            prepend_cond = prepend_cond.to(dtype)

            if prepend_cond_mask is not None:
                prepend_cond_mask = prepend_cond_mask.to(dtype)

        backbone_input = backbone_input.to(dtype)

        output = self.backbone(
            backbone_input,
            cross_attn_cond=cross_attn_cond,
            prepend_cond=prepend_cond,
            prepend_cond_mask=prepend_cond_mask,
            **kwargs
        )  # [batch, seq_len, embed_dim]

        # Run each quantizer head on the backbone output
        logits = torch.stack([self.quantizer_heads[i](output) for i in range(num_quantizers)], dim=1)

        return logits
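
    # Illustrative shape walk-through (not part of the original module), for a backbone
    # with embedding dimension D, K codebooks and S pattern steps:
    #   sequence              [B, K, S]  integer codes (masked_token_id allowed)
    #   summed embeddings     [B, S, D]  sum of the K per-codebook embeddings
    #   backbone output       [B, S, D]
    #   stacked head logits   [B, K, S, codebook_size]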

    def compute_logits(
        self,
        codes,  # [batch, num_quantizers, seq_len]
        **kwargs
    ):
        """
        Compute logits for a batch of codes, optionally conditioning on cross-attention and prepend conditioning.
        Handles translation between the input sequence and the pattern-shifted sequence.
        Only used during training.
        """

        batch, _, seq_len = codes.shape

        pattern = self.pattern_provider.get_pattern(seq_len)

        # Apply the codebook pattern, shifting the codes and filling gaps with the mask token
        shifted_codes, _, _ = pattern.build_pattern_sequence(
            codes,
            self.masked_token_id,
            keep_only_valid_steps=True
        )

        # Run the model on the pattern-shifted codes
        logits = self(shifted_codes, **kwargs)

        # Rearrange codebook and sequence dimensions for pattern reversion
        logits = rearrange(logits, "b n s c -> b c n s")

        # Revert the pattern, filling positions without a valid logit with NaN
        logits, _, logits_mask = pattern.revert_pattern_logits(
            logits, float('nan'), keep_only_valid_steps=True
        )

        logits = rearrange(logits, "b c n t -> b n t c")

        logits_mask = logits_mask[None, :, :].expand(batch, -1, -1)  # [batch, num_quantizers, seq_len]

        return LMOutput(logits=logits, mask=logits_mask)

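
# Illustrative only: a small sketch (not part of the original module) of the codebook
# pattern round-trip that compute_logits and generate rely on. With a delay pattern,
# codebook k is shifted by k steps and the introduced gaps are filled with a special
# token; revert_pattern_sequence undoes the shift. The helper name is hypothetical.
def _example_delay_pattern_roundtrip(num_quantizers: int = 4, seq_len: int = 8, codebook_size: int = 1024):
    provider = DelayedPatternProvider(n_q=num_quantizers)
    pattern = provider.get_pattern(seq_len)

    codes = torch.randint(0, codebook_size, (1, num_quantizers, seq_len))

    # Shift into the pattern layout, using codebook_size as the mask/special token
    shifted, _, _ = pattern.build_pattern_sequence(codes, codebook_size)

    # Undo the shift; positions without a corresponding pattern step come back as -1
    reverted, _, mask = pattern.revert_pattern_sequence(shifted, special_token=-1)

    return shifted, reverted, mask
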

class AudioLanguageModelWrapper(nn.Module):
    def __init__(
        self,
        pretransform: Pretransform,
        lm: AudioLanguageModel,
        sample_rate: int,
        min_input_length: int,
        conditioner: tp.Optional[MultiConditioner] = None,
        cross_attn_cond_ids: tp.List[str] = [],
        prepend_cond_ids: tp.List[str] = [],
        global_cond_ids: tp.List[str] = []
    ):
        super().__init__()

        assert pretransform.is_discrete, "Pretransform must be discrete"
        self.pretransform = pretransform

        self.pretransform.requires_grad_(False)
        self.pretransform.eval()

        if isinstance(self.pretransform, AutoencoderPretransform):
            self.num_quantizers = self.pretransform.model.bottleneck.num_quantizers
            self.codebook_size = self.pretransform.model.bottleneck.codebook_size
        elif isinstance(self.pretransform, PretrainedDACPretransform):
            self.num_quantizers = self.pretransform.model.num_quantizers
            self.codebook_size = self.pretransform.model.codebook_size
        elif isinstance(self.pretransform, AudiocraftCompressionPretransform):
            self.num_quantizers = self.pretransform.num_quantizers
            self.codebook_size = self.pretransform.codebook_size
        else:
            raise NotImplementedError(f"Unrecognized pretransform type {type(self.pretransform)}")

        self.conditioner = conditioner

        self.lm = lm

        self.sample_rate = sample_rate
        self.min_input_length = min_input_length

        self.cross_attn_cond_ids = cross_attn_cond_ids
        self.prepend_cond_ids = prepend_cond_ids
        self.global_cond_ids = global_cond_ids

    def get_conditioning_inputs(self, cond: tp.Dict[str, tp.Any], negative=False):
        cross_attention_input = None
        prepend_cond = None
        prepend_cond_mask = None
        global_cond = None

        if len(self.cross_attn_cond_ids) > 0:
            # Concatenate all cross-attention inputs over the sequence dimension
            # Assumes the cross-attention inputs are of shape (batch, seq, channels)
            cross_attention_input = torch.cat([cond[key][0] for key in self.cross_attn_cond_ids], dim=1)

        if len(self.prepend_cond_ids) > 0:
            # Concatenate all prepend conditioning inputs over the sequence dimension
            # Assumes the prepend conditioning inputs are of shape (batch, seq, channels)
            prepend_cond = torch.cat([cond[key][0] for key in self.prepend_cond_ids], dim=1)
            prepend_cond_mask = torch.cat([cond[key][1] for key in self.prepend_cond_ids], dim=1)

        if len(self.global_cond_ids) > 0:
            # Concatenate all global conditioning inputs over the channel dimension,
            # squeezing away a singleton sequence dimension if present
            global_cond = torch.cat([cond[key][0] for key in self.global_cond_ids], dim=-1)
            if len(global_cond.shape) == 3:
                global_cond = global_cond.squeeze(1)

        if negative:
            return {
                "negative_cross_attn_cond": cross_attention_input,
                "negative_prepend_cond": prepend_cond,
                "negative_prepend_cond_mask": prepend_cond_mask,
                "negative_global_cond": global_cond
            }
        else:
            return {
                "cross_attn_cond": cross_attention_input,
                "prepend_cond": prepend_cond,
                "prepend_cond_mask": prepend_cond_mask,
                "global_cond": global_cond
            }
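
    # Illustrative only (not part of the original module): the `cond` dictionary passed
    # to get_conditioning_inputs is expected to map each conditioner output id to a
    # (tensor, mask) pair, roughly:
    #   cond = {
    #       "prompt": (embeddings,       # [batch, seq, channels]
    #                  attention_mask),  # [batch, seq]
    #   }
    # which is what the cond[key][0] / cond[key][1] indexing above assumes. The key
    # name "prompt" is just a placeholder; real ids come from the conditioning config.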

    def compute_logits(
        self,
        codes,
        condition_tensors=None,
        cfg_dropout_prob=0.0,
        **kwargs
    ):
        """
        Compute logits for a batch of codes, translating from conditioning inputs to model inputs.
        Handles CFG dropout.
        """

        if condition_tensors is None:
            condition_tensors = {}

        conditioning_inputs = self.get_conditioning_inputs(condition_tensors)

        cross_attn_cond = conditioning_inputs["cross_attn_cond"]
        prepend_cond = conditioning_inputs["prepend_cond"]
        prepend_cond_mask = conditioning_inputs["prepend_cond_mask"]
        global_cond = conditioning_inputs["global_cond"]

        if cfg_dropout_prob > 0.0:
            # Randomly replace each conditioning signal with a null (all-zero) embedding
            # for a fraction of the batch, so the model also learns unconditional prediction
            if cross_attn_cond is not None:
                null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
                dropout_mask = torch.bernoulli(torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to(torch.bool)
                cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond)

            if prepend_cond is not None:
                null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
                dropout_mask = torch.bernoulli(torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(torch.bool)
                prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond)

            if global_cond is not None:
                null_embed = torch.zeros_like(global_cond, device=global_cond.device)
                dropout_mask = torch.bernoulli(torch.full((global_cond.shape[0], 1), cfg_dropout_prob, device=global_cond.device)).to(torch.bool)
                global_cond = torch.where(dropout_mask, null_embed, global_cond)

        return self.lm.compute_logits(codes, cross_attn_cond=cross_attn_cond, prepend_cond=prepend_cond, prepend_cond_mask=prepend_cond_mask, global_cond=global_cond, **kwargs)
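
    # Illustrative only (not part of the original module): a rough sketch of how a
    # training step might drive compute_logits, assuming `codes` come from the discrete
    # pretransform and `metadata` is whatever the configured conditioner expects:
    #   condition_tensors = model.conditioner(metadata, device)
    #   output = model.compute_logits(codes, condition_tensors=condition_tensors, cfg_dropout_prob=0.1)
    #   loss = cross-entropy between output.logits and codes at positions where output.mask is True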

    def _sample_next_token(
        self,
        sequence,  # [batch, num_quantizers, seq_len]
        conditioning_tensors=None,
        cross_attn_use_cfg=True,
        prepend_use_cfg=True,
        global_use_cfg=True,
        cfg_scale=1.0,
        top_k=250,
        top_p=0.0,
        temp=1.0,
        **kwargs
    ):
        """
        Sample the next token for a batch of codes, translating from conditioning inputs to model inputs.
        Handles CFG inference.
        """

        if conditioning_tensors is None:
            conditioning_tensors = {}

        conditioning_inputs = self.get_conditioning_inputs(conditioning_tensors)

        cross_attn_cond = conditioning_inputs["cross_attn_cond"]
        prepend_cond = conditioning_inputs["prepend_cond"]
        prepend_cond_mask = conditioning_inputs["prepend_cond_mask"]
        global_cond = conditioning_inputs["global_cond"]

        if cfg_scale != 1.0:
            # Batch the conditional and unconditional (null-conditioned) passes together
            sequence = torch.cat([sequence, sequence], dim=0)

            if cross_attn_cond is not None and cross_attn_use_cfg:
                null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)

                cross_attn_cond = torch.cat([cross_attn_cond, null_embed], dim=0)

            if prepend_cond is not None and prepend_use_cfg:
                null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)

                prepend_cond = torch.cat([prepend_cond, null_embed], dim=0)

                if prepend_cond_mask is not None:
                    prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0)

            if global_cond is not None and global_use_cfg:
                null_embed = torch.zeros_like(global_cond, device=global_cond.device)

                global_cond = torch.cat([global_cond, null_embed], dim=0)

        logits = self.lm(sequence, cross_attn_cond=cross_attn_cond, prepend_cond=prepend_cond, prepend_cond_mask=prepend_cond_mask, global_cond=global_cond, **kwargs)

        if cfg_scale != 1.0:
            cond_logits, uncond_logits = logits.chunk(2, dim=0)

            logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale

        logits = rearrange(logits, "b n s c -> b n c s")  # [batch, num_quantizers, codebook_size, seq_len]

        # Keep only the logits for the last step
        logits = logits[:, :, :, -1]

        # Temperature sampling with optional top-p / top-k, or greedy argmax when temp is 0
        if temp > 0:
            probs = torch.softmax(logits / temp, dim=-1)

            if top_p > 0.0:
                next_token = sample_top_p(probs, p=top_p)
            elif top_k > 0:
                next_token = sample_top_k(probs, k=top_k)
            else:
                next_token = multinomial(probs, num_samples=1)
        else:
            next_token = torch.argmax(logits, dim=-1, keepdim=True)

        return next_token  # [batch, num_quantizers, 1]
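
    # Illustrative note (not part of the original module): the classifier-free guidance
    # combination used above is
    #   logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
    # so cfg_scale == 1.0 reproduces the purely conditional logits (and the second,
    # null-conditioned pass is skipped entirely), while cfg_scale > 1.0 pushes the
    # prediction further away from the unconditional distribution.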

    @torch.no_grad()
    def generate(
        self,
        max_gen_len: int = 256,
        batch_size: tp.Optional[int] = None,
        init_data: tp.Optional[torch.Tensor] = None,
        conditioning: tp.Optional[tp.Dict[str, tp.Any]] = None,
        conditioning_tensors: tp.Optional[tp.Dict[str, tp.Any]] = None,
        callback: tp.Optional[tp.Callable[[int, int], None]] = None,
        use_cache: bool = True,
        cfg_scale: float = 1.0,
        **kwargs
    ):
        device = next(self.parameters()).device

        if conditioning_tensors is None and conditioning is not None:
            # Convert conditioning inputs to conditioning tensors
            conditioning_tensors = self.conditioner(conditioning, device)

        # Check that batch size is consistent across inputs
        possible_batch_sizes = []

        if batch_size is not None:
            possible_batch_sizes.append(batch_size)
        elif init_data is not None:
            possible_batch_sizes.append(init_data.shape[0])
        elif conditioning_tensors is not None:
            # Assume the first conditioning tensor carries the batch dimension
            possible_batch_sizes.append(conditioning_tensors[list(conditioning_tensors.keys())[0]][0].shape[0])
        else:
            possible_batch_sizes.append(1)

        assert all(x == possible_batch_sizes[0] for x in possible_batch_sizes), "Batch size must be consistent across inputs"

        batch_size = possible_batch_sizes[0]

        if init_data is None:
            # Initialize with an empty sequence
            assert batch_size > 0
            init_data = torch.zeros((batch_size, self.num_quantizers, 0), device=device, dtype=torch.long)

        batch_size, num_quantizers, seq_len = init_data.shape

        start_offset = seq_len
        assert start_offset < max_gen_len, "init data longer than max gen length"

        pattern = self.lm.pattern_provider.get_pattern(max_gen_len)

        unknown_token = -1

        # Initialize the generated codes with the init data, filling the rest with unknown tokens
        gen_codes = torch.full((batch_size, num_quantizers, max_gen_len), unknown_token, device=device, dtype=torch.long)
        gen_codes[:, :, :start_offset] = init_data

        gen_sequence, _, mask = pattern.build_pattern_sequence(gen_codes, self.lm.masked_token_id)

        start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
        assert start_offset_sequence is not None

        # Generation loop state
        prev_offset = 0
        gen_sequence_len = gen_sequence.shape[-1]

        # Reset the backbone's generation cache (e.g. KV cache), doubling the batch for CFG
        if use_cache and self.lm.backbone.use_generation_cache:
            self.lm.backbone.reset_generation_cache(max_gen_len, batch_size if cfg_scale == 1.0 else batch_size * 2)

        for offset in trange(start_offset_sequence, gen_sequence_len):

            # Sequence seen by the model this step: everything since prev_offset
            # (the full prefix without caching, only the most recent tokens with caching)
            curr_sequence = gen_sequence[..., prev_offset:offset]

            next_token = self._sample_next_token(
                curr_sequence,
                conditioning_tensors=conditioning_tensors,
                use_cache=use_cache,
                cfg_scale=cfg_scale,
                **kwargs
            )

            valid_mask = mask[..., offset:offset+1].expand(batch_size, -1, -1)
            next_token[~valid_mask] = self.lm.masked_token_id

            # Update the generated sequence with the next token, keeping any tokens that were already set
            gen_sequence[..., offset:offset+1] = torch.where(
                gen_sequence[..., offset:offset+1] == unknown_token,
                next_token,
                gen_sequence[..., offset:offset+1]
            )

            if use_cache and self.lm.backbone.use_generation_cache:
                # Only feed in the most recent token on the next step when caching is used
                prev_offset = offset

                self.lm.backbone.update_generation_cache(offset)

            if callback is not None:
                # Report progress relative to the start of the generated sequence
                callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence)

        assert not (gen_sequence == unknown_token).any(), "Unknown tokens in generated sequence"

        out_codes, _, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token)

        # Sanity checks on the reverted codes and their mask
        assert (out_codes[..., :max_gen_len] != unknown_token).all()
        assert (out_mask[..., :max_gen_len] == 1).all()

        return out_codes
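
    # Illustrative only (not part of the original module): a hypothetical call producing
    # raw token codes, assuming `cond_metadata` is structured however the configured
    # MultiConditioner expects. Sampling arguments (top_k, top_p, temp) are forwarded to
    # _sample_next_token via **kwargs:
    #   codes = model.generate(
    #       max_gen_len=1024,
    #       conditioning=cond_metadata,
    #       cfg_scale=3.0,
    #       top_k=250,
    #       temp=1.0,
    #   )  # [batch, num_quantizers, max_gen_len]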

    def generate_audio(
        self,
        **kwargs
    ):
        """
        Generate token codes and decode them to audio with the pretransform.
        """

        codes = self.generate(**kwargs)

        audio = self.pretransform.decode_tokens(codes)

        return audio

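
# Illustrative only: a hypothetical end-to-end helper (not part of the original module)
# showing how a constructed AudioLanguageModelWrapper is typically driven. The metadata
# format depends entirely on the conditioner configuration, and the duration-to-steps
# conversion assumes the pretransform reports its downsampling_ratio.
def _example_generate_seconds(model: AudioLanguageModelWrapper, cond_metadata, seconds: float = 10.0):
    # Number of latent steps covering the requested duration
    steps = int(seconds * model.sample_rate / model.pretransform.downsampling_ratio)

    return model.generate_audio(
        max_gen_len=steps,
        conditioning=cond_metadata,
        cfg_scale=3.0,
        top_k=250,
        temp=1.0,
    )  # decoded audio tensor from the pretransform
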
def create_audio_lm_from_config(config):
    model_config = config.get('model', None)
    assert model_config is not None, 'model config must be specified in config'

    sample_rate = config.get('sample_rate', None)
    assert sample_rate is not None, "Must specify sample_rate in config"

    lm_config = model_config.get('lm', None)
    assert lm_config is not None, 'lm config must be specified in model config'

    codebook_pattern = lm_config.get("codebook_pattern", "delay")

    pattern_providers = {
        'parallel': ParallelPatternProvider,
        'delay': DelayedPatternProvider,
        'unroll': UnrolledPatternProvider,
        'musiclm': MusicLMPattern,
    }

    pretransform_config = model_config.get("pretransform", None)

    pretransform = create_pretransform_from_config(pretransform_config, sample_rate)

    assert pretransform.is_discrete, "Pretransform must be discrete"

    min_input_length = pretransform.downsampling_ratio

    pattern_provider = pattern_providers[codebook_pattern](n_q=pretransform.num_quantizers)

    conditioning_config = model_config.get('conditioning', None)

    conditioner = None
    if conditioning_config is not None:
        conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config)

    cross_attn_cond_ids = lm_config.get('cross_attention_cond_ids', [])
    prepend_cond_ids = lm_config.get('prepend_cond_ids', [])
    global_cond_ids = lm_config.get('global_cond_ids', [])

    lm_type = lm_config.get("type", None)
    lm_model_config = lm_config.get("config", None)

    assert lm_type is not None, "Must specify lm type in lm config"
    assert lm_model_config is not None, "Must specify lm model config in lm config"

    if lm_type == "x-transformers":
        backbone = XTransformersAudioLMBackbone(**lm_model_config)
    elif lm_type == "continuous_transformer":
        backbone = ContinuousTransformerAudioLMBackbone(**lm_model_config)
    else:
        raise NotImplementedError(f"Unrecognized lm type {lm_type}")

    lm = AudioLanguageModel(
        pattern_provider=pattern_provider,
        backbone=backbone,
        num_quantizers=pretransform.num_quantizers,
        codebook_size=pretransform.codebook_size
    )

    model = AudioLanguageModelWrapper(
        pretransform=pretransform,
        lm=lm,
        conditioner=conditioner,
        sample_rate=sample_rate,
        min_input_length=min_input_length,
        cross_attn_cond_ids=cross_attn_cond_ids,
        prepend_cond_ids=prepend_cond_ids,
        global_cond_ids=global_cond_ids
    )

    return model
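
# Illustrative only (not part of the original module): a rough sketch of the config
# structure create_audio_lm_from_config reads, based solely on the keys accessed above.
# The contents of the "pretransform", "conditioning" and lm "config" blocks depend on
# their respective factories and are left as placeholders.
#
#   config = {
#       "sample_rate": 44100,
#       "model": {
#           "pretransform": {...},   # must describe a discrete pretransform
#           "conditioning": {...},   # optional
#           "lm": {
#               "type": "continuous_transformer",    # or "x-transformers"
#               "codebook_pattern": "delay",         # parallel | delay | unroll | musiclm
#               "cross_attention_cond_ids": [...],   # optional
#               "prepend_cond_ids": [...],           # optional
#               "global_cond_ids": [...],            # optional
#               "config": {...},                     # backbone keyword arguments
#           },
#       },
#   }
#
#   model = create_audio_lm_from_config(config)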