|
import logging
import math
import typing as tp
from dataclasses import dataclass
from functools import partial

import numpy as np
import torch
from torch import nn

from audiocraft.activations import get_activation_fn
from audiocraft.transformer import StreamingTransformer
|
|
|
def _shift(x):
    """Roll each batch element by a random offset along the time axis (in place).

    Args:
        x (torch.Tensor): Codes of shape [B, K, T].
    """
    n = x.shape[2]
    for i, _slice in enumerate(x):
        # Draw an offset between roughly 24% and 74% of the sequence length,
        # guarding against degenerate bounds for very short sequences.
        low = max(1, int(.24 * n))
        high = max(low + 1, int(.74 * n))
        offset = np.random.randint(low, high)
        x[i, :, :] = torch.roll(_slice, offset, dims=1)
    return x
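
# Usage sketch (hypothetical shapes): every batch element of a [B, K, T]
# code tensor gets its own random roll along time, e.g.
#
#   codes = torch.randint(0, 1024, (2, 4, 250))  # [B=2, K=4, T=250]
#   codes = _shift(codes)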
|
|
logger = logging.getLogger(__name__)

TextCondition = tp.Optional[str]  # a text condition can be a string or None (null condition)
ConditionType = tp.Tuple[torch.Tensor, torch.Tensor]  # condition, mask
ConditionTensors = tp.Dict[str, ConditionType]
CFGConditions = tp.Union[ConditionTensors, tp.Tuple[ConditionTensors, ConditionTensors]]
|
|
|
|
|
def get_init_fn(method: str, input_dim: int, init_depth: tp.Optional[int] = None):
    """LM layer initialization.
    Inspired by xlformers: https://github.com/fairinternal/xlformers

    Args:
        method (str): Method name for init function. Valid options are:
            'gaussian', 'uniform'.
        input_dim (int): Input dimension of the initialized module.
        init_depth (int, optional): Optional init depth value used to rescale
            the standard deviation if defined.
    """
    std = 1 / math.sqrt(input_dim)
    if init_depth is not None:
        std = std / math.sqrt(2 * init_depth)

    if method == 'gaussian':
        return partial(
            torch.nn.init.trunc_normal_, mean=0.0, std=std, a=-3 * std, b=3 * std
        )
    elif method == 'uniform':
        bound = math.sqrt(3) * std  # bound chosen so the distribution has std `std`
        return partial(torch.nn.init.uniform_, a=-bound, b=bound)
    else:
        raise ValueError(f"Unsupported layer initialization method: {method}")
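
# Usage sketch: the returned partial applies in place to a parameter tensor
# (names below are illustrative only):
#
#   init_fn = get_init_fn('gaussian', input_dim=512, init_depth=12)
#   init_fn(some_linear.weight)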
|
|
|
|
|
def init_layer(m: nn.Module,
               method: str,
               init_depth: tp.Optional[int] = None,
               zero_bias_init: bool = False):
    """Wrapper around ``get_init_fn`` for proper initialization of LM modules.

    Args:
        m (nn.Module): Module to initialize.
        method (str): Method name for the init function.
        init_depth (int, optional): Optional init depth value used to rescale
            the standard deviation if defined.
        zero_bias_init (bool): Whether to initialize the bias to 0 or not.
    """
    if isinstance(m, nn.Linear):
        init_fn = get_init_fn(method, m.in_features, init_depth=init_depth)
        if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16:
            # Initialize in float32: in-place init ops are not supported for fp16 on CPU.
            weight = m.weight.float()
            init_fn(weight)
            m.weight.data[:] = weight.half()
        else:
            init_fn(m.weight)
        if zero_bias_init and m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Embedding):
        init_fn = get_init_fn(method, m.embedding_dim, init_depth=None)
        if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16:
            # Same fp16-on-CPU workaround as for Linear layers above.
            weight = m.weight.float()
            init_fn(weight)
            m.weight.data[:] = weight.half()
        else:
            init_fn(m.weight)
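
# Usage sketch: `init_layer` composes with `nn.Module.apply`, mirroring how
# `LMModel._init_weights` uses it below:
#
#   layer.apply(partial(init_layer, method='gaussian',
#                       init_depth=4, zero_bias_init=True))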
|
|
|
|
|
class ScaledEmbedding(nn.Embedding):
    """Boost learning rate for embeddings (with ``lr``)."""
    def __init__(self, *args, lr=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.lr = lr

    def make_optim_group(self):
        group = {"params": list(self.parameters())}
        if self.lr is not None:
            group["lr"] = self.lr
        return group
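
# Usage sketch: per-module optimizer groups let the embedding tables use
# their own learning rate (values here are illustrative):
#
#   emb = ScaledEmbedding(1025, 128, lr=1e-4)
#   optim = torch.optim.AdamW([emb.make_optim_group()], lr=3e-5)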
|
|
|
|
|
@dataclass
class LMOutput:
    # The logits are already re-aligned with the input codes.
    logits: torch.Tensor  # [B, K, T, card]
    mask: torch.Tensor  # [B, K, T], mask over the valid positions of the logits
|
|
|
|
|
class LMModel(nn.Module):
    """Transformer-based language model on multiple streams of codes.

    Args:
        pattern_provider (CodebooksPatternProvider): Pattern provider for codebook interleaving.
        condition_provider (MusicConditioningProvider): Conditioning provider from metadata.
        n_q (int): Number of parallel streams to model.
        card (int): Cardinality, vocabulary size.
        dim (int): Dimension of the transformer encoder.
        num_heads (int): Number of heads for the transformer encoder.
        hidden_scale (int): Scale for hidden feed forward dimension of the transformer encoder.
        norm (str): Normalization method.
        norm_first (bool): Use pre-norm instead of post-norm.
        emb_lr (float, optional): Embedding-specific learning rate.
        bias_proj (bool): Use bias for output projections.
        weight_init (str, optional): Method for weight initialization.
        depthwise_init (str, optional): Method for depthwise weight initialization.
        zero_bias_init (bool): If true and the Linears use a bias, initialize the bias to zeros.
        cfg_dropout (float): Classifier-free guidance dropout.
        cfg_coef (float): Classifier-free guidance coefficient.
        two_step_cfg (bool): Whether to run classifier-free guidance with 2 distinct steps.
        **kwargs: Additional parameters for the transformer encoder.
    """
|
    def __init__(self,
                 pattern_provider,
                 condition_provider,
                 n_q: int = 8,
                 card: int = 1024,
                 dim: int = 128,
                 num_heads: int = 8,
                 hidden_scale: int = 4,
                 norm: str = 'layer_norm',
                 norm_first: bool = False,
                 emb_lr: tp.Optional[float] = None,
                 bias_proj: bool = True,
                 weight_init: tp.Optional[str] = None,
                 depthwise_init: tp.Optional[str] = None,
                 zero_bias_init: bool = False,
                 cfg_dropout: float = 0,
                 cfg_coef: float = 1.0,
                 two_step_cfg: bool = False,
                 **kwargs):
        super().__init__()
        self.cfg_coef = cfg_coef
        self.condition_provider = condition_provider
        self.card = card
        self.n_draw = 2  # number of parallel samples drawn at every decoding step
        embed_dim = self.card + 1  # extra entry for the special (ungenerated) token
        self.n_q = n_q
        self.dim = dim
        self.pattern_provider = pattern_provider
        self.two_step_cfg = two_step_cfg
        self.emb = nn.ModuleList([ScaledEmbedding(embed_dim, dim, lr=emb_lr) for _ in range(n_q)])
        if 'activation' in kwargs:
            kwargs['activation'] = get_activation_fn(kwargs['activation'])
        # Drop `layer_scale` if present; it is not forwarded to this StreamingTransformer.
        kwargs.pop('layer_scale', None)

        self.transformer = StreamingTransformer(
            d_model=dim,
            num_heads=num_heads,
            dim_feedforward=int(hidden_scale * dim),
            norm=norm,
            norm_first=norm_first, **kwargs)
        self.out_norm: tp.Optional[nn.Module] = None
        if norm_first:
            self.out_norm = nn.LayerNorm(dim, eps=1e-5)
        self.linears = nn.ModuleList([nn.Linear(dim, self.card, bias=bias_proj) for _ in range(n_q)])
        self._init_weights(weight_init, depthwise_init, zero_bias_init)
        self._fsdp: tp.Optional[nn.Module]
        # Write through __dict__ to bypass nn.Module.__setattr__, so a future
        # FSDP wrapper is not registered as a submodule.
        self.__dict__['_fsdp'] = None
|
|
|
    def _init_weights(self, weight_init: tp.Optional[str], depthwise_init: tp.Optional[str], zero_bias_init: bool):
        """Initialization of the transformer module weights.

        Args:
            weight_init (str, optional): Weight initialization strategy. See ``get_init_fn`` for valid options.
            depthwise_init (str, optional): Depthwise initialization strategy. The following options are valid:
                'current' where the depth corresponds to the current layer index or 'global' where the total number
                of layers is used as depth. If not set, no depthwise initialization strategy is used.
            zero_bias_init (bool): Whether to initialize bias to zero or not.
        """
        assert depthwise_init is None or depthwise_init in ['current', 'global']
        assert depthwise_init is None or weight_init is not None, \
            "If 'depthwise_init' is defined, a 'weight_init' method should be provided."
        assert not zero_bias_init or weight_init is not None, \
            "If 'zero_bias_init', a 'weight_init' method should be provided"

        if weight_init is None:
            return

        for emb_layer in self.emb:
            init_layer(emb_layer, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)

        for layer_idx, tr_layer in enumerate(self.transformer.layers):
            depth = None
            if depthwise_init == 'current':
                depth = layer_idx + 1
            elif depthwise_init == 'global':
                depth = len(self.transformer.layers)
            init_fn = partial(init_layer,
                              method=weight_init,
                              init_depth=depth,
                              zero_bias_init=zero_bias_init)
            tr_layer.apply(init_fn)

        for linear in self.linears:
            init_layer(linear, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)
|
|
|
    @property
    def special_token_id(self) -> int:
        # Index `card` is the extra embedding entry used for yet-ungenerated
        # (masked) positions in the interleaved pattern.
        return self.card
|
|
    def forward(self,
                sequence,
                condition_tensors=None,
                token_count=None):
        """Single decoding step: returns sampled tokens of shape [B, n_draw, K].

        `sequence` holds the codes of the previous step, shape [B, K, 1].
        """
        bs, _, _ = sequence.shape

        input_ = sum([self.emb[k](sequence[:, k]) for k in range(self.n_q)])
        # Duplicate the batch: the first half attends to the real text
        # condition, the second half to the null condition (see `generate`).
        out = self.transformer(torch.cat([input_, input_], 0),
                               cross_attention_src=condition_tensors,
                               token_count=token_count)
        if self.out_norm:
            out = self.out_norm(out)

        logits = torch.stack([self.linears[k](out) for k in range(self.n_q)], dim=1)  # [2*B, K, 1, card]

        # Classifier-free guidance with a hardcoded scale of 3 (`self.cfg_coef`
        # is not used here): logits = uncond + 3 * (cond - uncond).
        logits = 3 * logits[:bs, :, :, :] - 2 * logits[bs:, :, :, :]

        # Top-k sampling: zero out everything below the k-th largest
        # probability, then renormalize.
        k = 250
        p = torch.softmax(logits, dim=3)
        top_k_value, _ = torch.topk(p, k, dim=3)
        min_value_top_k = top_k_value[:, :, :, -1:]
        p *= (p >= min_value_top_k).float()
        p.div_(p.sum(dim=-1, keepdim=True))

        # Draw `n_draw` samples per (batch, codebook) position.
        p = p.reshape(bs * self.n_q, -1)
        out = torch.multinomial(p,
                                num_samples=self.n_draw,
                                replacement=True)
        return out.reshape(bs, self.n_q, self.n_draw).transpose(1, 2)  # [B, n_draw, K]
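
    # Worked top-k example (hypothetical numbers): with probabilities
    # [0.5, 0.3, 0.15, 0.05] and k=2, everything below the k-th largest
    # value (0.3) is zeroed, and renormalizing gives [0.625, 0.375, 0, 0].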
|
|
|
    @torch.no_grad()
    def generate(self,
                 conditions=None,
                 max_gen_len=256):
        """Generate token codes for the given text conditions."""
        if conditions is None:
            conditions = []

        tokenized = self.condition_provider.tokenize(conditions)
        cfg_conditions = self.condition_provider(tokenized)

        # Stack the real text condition with an all-zero (null) condition:
        # the two batch halves in `forward` implement classifier-free guidance.
        text_condition = cfg_conditions['description'][0]
        bs, _, _ = text_condition.shape
        text_condition = torch.cat(
            [
                text_condition,
                torch.zeros_like(text_condition)
            ], 0)
|
|
        pattern = self.pattern_provider.get_pattern(max_gen_len)
        # Start from an all-unknown (-1) grid of codes: [B, K, T].
        gen_codes = torch.full((bs,
                                self.n_q,
                                max_gen_len), -1, dtype=torch.long,
                               device=text_condition.device)
        # Interleave the codebooks following the delay pattern; positions that
        # are never generated are filled with the special token.
        gen_sequence, _, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id)
        _, _, audiodur = gen_sequence.shape

        # Replicate along an `n_draw` axis so each step keeps all drawn samples:
        # mask and gen_sequence both become [B, n_draw, K, T].
        mask = mask[None, None, :, :].repeat(bs, self.n_draw, 1, 1)
        gen_sequence = gen_sequence[:, None, :, :].repeat(1, self.n_draw, 1, 1)
|
|
        for offset in range(1, audiodur):
            # Feed the previous step of the first draw; `forward` returns
            # `n_draw` candidate tokens per codebook: [B, n_draw, K].
            next_token = self.forward(gen_sequence[:, 0, :, offset-1:offset],
                                      condition_tensors=text_condition,
                                      token_count=offset-1)

            # Positions the pattern marks as invalid get the special token.
            m = mask[:, :, :, offset]
            next_token[~m] = self.special_token_id
            # Only fill positions that are still unknown (-1).
            gen_sequence[:, :, :, offset] = torch.where(
                gen_sequence[:, :, :, offset] == -1,
                next_token,
                gen_sequence[:, :, :, offset]
            )
|
|
        # Undo the interleaving pattern; -1 marks tokens to drop.
        out_codes, _, _ = pattern.revert_pattern_sequence(
            gen_sequence.reshape(bs * self.n_draw, self.n_q, audiodur),
            special_token=-1)

        # Concatenate the n_draw variations along time: [B, K, n_draw * T'].
        _, _, new_len = out_codes.shape
        out_codes = out_codes.reshape(bs, self.n_draw, self.n_q, new_len)
        out_codes = out_codes.transpose(1, 2).reshape(bs, self.n_q, self.n_draw * new_len)
        logger.debug('out_codes shape: %s', out_codes.shape)
        # Randomly roll the codes along time a few times (see `_shift`).
        for _ in range(7):
            out_codes = _shift(out_codes)

        # Reset the per-layer KV cache before the next call.
        for lay in self.transformer.layers:
            lay.self_attn.k_history = None
            lay.self_attn.v_history = None

        return out_codes
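
# Usage sketch (assumes a fully constructed model; the provider and pattern
# objects below are placeholders, not defined in this file):
#
#   lm = LMModel(pattern_provider, condition_provider, n_q=4, card=2048)
#   codes = lm.generate(conditions=['uplifting piano'], max_gen_len=500)
#   # codes: [B, n_q, n_draw * T'] token ids for the audio decoder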
|
|