Spaces:
Build error
Build error
Jonathan Fly
Use greedy sampling path when temp is 0.0 to avoid division by zero (#53)
8764625
unverified
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
from dataclasses import dataclass | |
from functools import partial | |
import logging | |
import math | |
import typing as tp | |
import torch | |
from torch import nn | |
from ..utils import utils | |
from ..modules.streaming import StreamingModule, State | |
from ..modules.transformer import StreamingTransformer, create_norm_fn | |
from ..modules.conditioners import ( | |
ConditionFuser, | |
ClassifierFreeGuidanceDropout, | |
AttributeDropout, | |
ConditioningProvider, | |
ConditioningAttributes, | |
ConditionType, | |
) | |
from ..modules.codebooks_patterns import CodebooksPatternProvider | |
from ..modules.activations import get_activation_fn | |
logger = logging.getLogger(__name__) | |
ConditionTensors = tp.Dict[str, ConditionType] | |
CFGConditions = tp.Union[ConditionTensors, tp.Tuple[ConditionTensors, ConditionTensors]] | |
def get_init_fn(method: str, input_dim: int, init_depth: tp.Optional[int] = None): | |
"""LM layer initialization. | |
Inspired from xlformers: https://github.com/fairinternal/xlformers | |
Args: | |
method (str): Method name for init function. Valid options are: | |
'gaussian', 'uniform'. | |
input_dim (int): Input dimension of the initialized module. | |
init_depth (Optional[int]): Optional init depth value used to rescale | |
the standard deviation if defined. | |
""" | |
# Compute std | |
std = 1 / math.sqrt(input_dim) | |
# Rescale with depth | |
if init_depth is not None: | |
std = std / math.sqrt(2 * init_depth) | |
if method == 'gaussian': | |
return partial( | |
torch.nn.init.trunc_normal_, mean=0.0, std=std, a=-3 * std, b=3 * std | |
) | |
elif method == 'uniform': | |
bound = math.sqrt(3) * std # ensure the standard deviation is `std` | |
return partial(torch.nn.init.uniform_, a=-bound, b=bound) | |
else: | |
raise ValueError("Unsupported layer initialization method") | |
def init_layer(m: nn.Module, | |
method: str, | |
init_depth: tp.Optional[int] = None, | |
zero_bias_init: bool = False): | |
"""Wrapper around ``get_init_fn`` for proper initialization of LM modules. | |
Args: | |
m (nn.Module): Module to initialize. | |
method (str): Method name for the init function. | |
init_depth (Optional[int]): Optional init depth value used to rescale | |
the standard deviation if defined. | |
zero_bias_init (bool): Whether to initialize the bias to 0 or not. | |
""" | |
if isinstance(m, nn.Linear): | |
init_fn = get_init_fn(method, m.in_features, init_depth=init_depth) | |
if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16: | |
weight = m.weight.float() | |
init_fn(weight) | |
m.weight.data[:] = weight.half() | |
else: | |
init_fn(m.weight) | |
if zero_bias_init and m.bias is not None: | |
nn.init.constant_(m.bias, 0) | |
elif isinstance(m, nn.Embedding): | |
init_fn = get_init_fn(method, m.embedding_dim, init_depth=None) | |
if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16: | |
weight = m.weight.float() | |
init_fn(weight) | |
m.weight.data[:] = weight.half() | |
else: | |
init_fn(m.weight) | |
class ScaledEmbedding(nn.Embedding): | |
"""Boost learning rate for embeddings (with `scale`). | |
""" | |
def __init__(self, *args, lr=None, **kwargs): | |
super().__init__(*args, **kwargs) | |
self.lr = lr | |
def make_optim_group(self): | |
group = {"params": list(self.parameters())} | |
if self.lr is not None: | |
group["lr"] = self.lr | |
return group | |
class LMOutput: | |
# The logits are already re-aligned with the input codes | |
# hence no extra shift is required, e.g. when computing CE | |
logits: torch.Tensor # [B, K, T, card] | |
mask: torch.Tensor # [B, K, T] | |
class LMModel(StreamingModule): | |
"""Transformer-based language model on multiple streams of codes. | |
Args: | |
pattern_provider (CodebooksPatternProvider): Pattern provider for codebook interleaving. | |
condition_provider (MusicConditioningProvider): Conditioning provider from metadata. | |
fuser (ConditionFuser): Fuser handling the fusing of conditions with language model input. | |
n_q (int): Number of parallel streams to model. | |
card (int): Cardinality, vocabulary size. | |
dim (int): Dimension of the transformer encoder. | |
num_heads (int): Number of heads for the transformer encoder. | |
hidden_scale (int): Scale for hidden feed forward dimension of the transformer encoder. | |
norm (str): Normalization method. | |
norm_first (bool): Use pre-norm instead of post-norm. | |
emb_lr (Optional[float]): Embedding-specific learning rate. | |
bias_proj (bool): Use bias for output projections. | |
weight_init (Optional[str]): Method for weight initialization. | |
depthwise_init (Optional[str]): Method for depthwise weight initialization. | |
zero_bias_init (bool): If true and bias in Linears, initialize bias to zeros. | |
cfg_dropout (float): Classifier-free guidance dropout. | |
cfg_coef (float): Classifier-free guidance coefficient. | |
attribute_dropout (dict): Attribute dropout probabilities. | |
two_step_cfg (bool): Whether to run classifier free-guidance with 2 distinct steps. | |
**kwargs: Additional parameters for the transformer encoder. | |
""" | |
def __init__(self, pattern_provider: CodebooksPatternProvider, condition_provider: ConditioningProvider, | |
fuser: ConditionFuser, n_q: int = 8, card: int = 1024, dim: int = 128, num_heads: int = 8, | |
hidden_scale: int = 4, norm: str = 'layer_norm', norm_first: bool = False, | |
emb_lr: tp.Optional[float] = None, bias_proj: bool = True, | |
weight_init: tp.Optional[str] = None, depthwise_init: tp.Optional[str] = None, | |
zero_bias_init: bool = False, cfg_dropout: float = 0, cfg_coef: float = 1.0, | |
attribute_dropout: tp.Dict[str, tp.Dict[str, float]] = {}, two_step_cfg: bool = False, | |
**kwargs): | |
super().__init__() | |
self.cfg_coef = cfg_coef | |
self.cfg_dropout = ClassifierFreeGuidanceDropout(p=cfg_dropout) | |
self.att_dropout = AttributeDropout(p=attribute_dropout) | |
self.condition_provider = condition_provider | |
self.fuser = fuser | |
self.card = card | |
embed_dim = self.card + 1 | |
self.n_q = n_q | |
self.dim = dim | |
self.pattern_provider = pattern_provider | |
self.two_step_cfg = two_step_cfg | |
self.emb = nn.ModuleList([ScaledEmbedding(embed_dim, dim, lr=emb_lr) for _ in range(n_q)]) | |
if 'activation' in kwargs: | |
kwargs['activation'] = get_activation_fn(kwargs['activation']) | |
self.transformer = StreamingTransformer( | |
d_model=dim, num_heads=num_heads, dim_feedforward=int(hidden_scale * dim), | |
norm=norm, norm_first=norm_first, **kwargs) | |
self.out_norm: tp.Optional[nn.Module] = None | |
if norm_first: | |
self.out_norm = create_norm_fn(norm, dim) | |
self.linears = nn.ModuleList([nn.Linear(dim, self.card, bias=bias_proj) for _ in range(n_q)]) | |
self._init_weights(weight_init, depthwise_init, zero_bias_init) | |
self._fsdp: tp.Optional[nn.Module] | |
self.__dict__['_fsdp'] = None | |
def _init_weights(self, weight_init: tp.Optional[str], depthwise_init: tp.Optional[str], zero_bias_init: bool): | |
"""Initialization of the transformer module weights. | |
Args: | |
weight_init (Optional[str]): Weight initialization strategy. See ``get_init_fn`` for valid options. | |
depthwise_init (Optional[str]): Depwthwise initialization strategy. The following options are valid: | |
'current' where the depth corresponds to the current layer index or 'global' where the total number | |
of layer is used as depth. If not set, no depthwise initialization strategy is used. | |
zero_bias_init (bool): Whether to initalize bias to zero or not. | |
""" | |
assert depthwise_init is None or depthwise_init in ['current', 'global'] | |
assert depthwise_init is None or weight_init is not None, \ | |
"If 'depthwise_init' is defined, a 'weight_init' method should be provided." | |
assert not zero_bias_init or weight_init is not None, \ | |
"If 'zero_bias_init', a 'weight_init' method should be provided" | |
if weight_init is None: | |
return | |
for emb_layer in self.emb: | |
init_layer(emb_layer, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init) | |
for layer_idx, tr_layer in enumerate(self.transformer.layers): | |
depth = None | |
if depthwise_init == 'current': | |
depth = layer_idx + 1 | |
elif depthwise_init == 'global': | |
depth = len(self.transformer.layers) | |
init_fn = partial(init_layer, method=weight_init, init_depth=depth, zero_bias_init=zero_bias_init) | |
tr_layer.apply(init_fn) | |
for linear in self.linears: | |
init_layer(linear, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init) | |
def special_token_id(self) -> int: | |
return self.card | |
def num_codebooks(self) -> int: | |
return self.n_q | |
def forward(self, sequence: torch.Tensor, | |
conditions: tp.List[ConditioningAttributes], | |
condition_tensors: tp.Optional[ConditionTensors] = None) -> torch.Tensor: | |
"""Apply language model on sequence and conditions. | |
Given a tensor of sequence of shape [B, K, S] with K the number of codebooks and | |
S the sequence steps, return the logits with shape [B, card, K, S]. | |
Args: | |
indices (torch.Tensor): indices of the codes to model. | |
conditions (list[ConditioningAttributes]): conditionings to use when modeling | |
the given codes. Note that when evaluating multiple time with the same conditioning | |
you should pre-compute those and pass them as `condition_tensors`. | |
condition_tensors (dict[str, ConditionType] or None): pre-computed conditioning | |
tensors, see `conditions`. | |
Returns: | |
torch.Tensor: Logits. | |
""" | |
B, K, S = sequence.shape | |
assert K == self.num_codebooks, 'Sequence shape must match the specified number of codebooks' | |
input_ = sum([self.emb[k](sequence[:, k]) for k in range(K)]) | |
if condition_tensors is None: | |
assert not self._is_streaming, "Conditions tensors should be precomputed when streaming." | |
# apply dropout modules | |
conditions = self.cfg_dropout(conditions) | |
conditions = self.att_dropout(conditions) | |
tokenized = self.condition_provider.tokenize(conditions) | |
# encode conditions and fuse, both have a streaming cache to not recompute when generating. | |
condition_tensors = self.condition_provider(tokenized) | |
else: | |
assert not conditions, "Shouldn't pass both conditions and condition_tensors." | |
input_, cross_attention_input = self.fuser(input_, condition_tensors) | |
out = self.transformer(input_, cross_attention_src=cross_attention_input) | |
if self.out_norm: | |
out = self.out_norm(out) | |
logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1) # [B, K, S, card] | |
# remove the prefix from the model outputs | |
if len(self.fuser.fuse2cond['prepend']) > 0: | |
logits = logits[:, :, -S:] | |
return logits # [B, K, S, card] | |
def compute_predictions( | |
self, codes: torch.Tensor, | |
conditions: tp.List[ConditioningAttributes], | |
condition_tensors: tp.Optional[ConditionTensors] = None) -> LMOutput: | |
"""Given an input tensor of codes [B, K, T] and list of conditions, runs the model | |
forward using the specified codes interleaving pattern. | |
Args: | |
codes (torch.Tensor): Input codes of shape [B, K, T] with B the batch size, | |
K the number of codebooks and T the number of timesteps. | |
conditions (list[ConditioningAttributes]): conditionings to use when modeling | |
the given codes. Note that when evaluating multiple time with the same conditioning | |
you should pre-compute those and pass them as `condition_tensors`. | |
condition_tensors (dict[str, ConditionType] or None): pre-computed conditioning | |
tensors, see `conditions`. | |
Returns: | |
LMOutput: Language model outputs | |
logits (torch.Tensor) of shape [B, K, T, card] corresponding to the provided codes, | |
i.e. the first item corresponds to logits to predict the first code, meaning that | |
no additional shifting of codes and logits is required. | |
mask (torch.Tensor) of shape [B, K, T], mask over valid and invalid positions. | |
Given the specified interleaving strategies, parts of the logits and codes should | |
not be considered as valid predictions because of invalid context. | |
""" | |
B, K, T = codes.shape | |
codes = codes.contiguous() | |
# map codes [B, K, T] into pattern sequence [B, K, S] using special_token_id for masked tokens | |
pattern = self.pattern_provider.get_pattern(T) | |
sequence_codes, sequence_indexes, sequence_mask = pattern.build_pattern_sequence( | |
codes, self.special_token_id, keep_only_valid_steps=True | |
) | |
# apply model on pattern sequence | |
model = self if self._fsdp is None else self._fsdp | |
logits = model(sequence_codes, conditions, condition_tensors) # [B, K, S, card] | |
# map back the logits on pattern sequence to logits on original codes: [B, K, S, card] -> [B, K, T, card] | |
# and provide the corresponding mask over invalid positions of tokens | |
logits = logits.permute(0, 3, 1, 2) # [B, card, K, S] | |
# note: we use nans as special token to make it obvious if we feed unexpected logits | |
logits, logits_indexes, logits_mask = pattern.revert_pattern_logits( | |
logits, float('nan'), keep_only_valid_steps=True | |
) | |
logits = logits.permute(0, 2, 3, 1) # [B, K, T, card] | |
logits_mask = logits_mask[None, :, :].expand(B, -1, -1) # [K, T] -> [B, K, T] | |
return LMOutput(logits, logits_mask) | |
def _sample_next_token(self, | |
sequence: torch.Tensor, | |
cfg_conditions: CFGConditions, | |
unconditional_state: State, | |
use_sampling: bool = False, | |
temp: float = 1.0, | |
top_k: int = 0, | |
top_p: float = 0.0, | |
cfg_coef: tp.Optional[float] = None) -> torch.Tensor: | |
"""Sample next token from the model given a sequence and a set of conditions. The model supports | |
multiple sampling strategies (greedy sampling, softmax, top-k, top-p...). | |
Args: | |
sequence (torch.Tensor): Current sequence of shape [B, K, S] | |
with K corresponding to the number of codebooks and S the number of sequence steps. | |
S = 1 in streaming mode, except for the first step that contains a bigger prompt. | |
condition_tensors (Dict[str, ConditionType): Set of conditions. If CFG is used, | |
should be twice the batch size, being the concatenation of the conditions + null conditions. | |
use_sampling (bool): Whether to use a sampling strategy or not. | |
temp (float): Sampling temperature. | |
top_k (int): K for "top-k" sampling. | |
top_p (float): P for "top-p" sampling. | |
cfg_coef (float): classifier free guidance coefficient | |
Returns: | |
next_token (torch.Tensor): Next token tensor of shape [B, K, 1]. | |
""" | |
B = sequence.shape[0] | |
cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef | |
model = self if self._fsdp is None else self._fsdp | |
if self.two_step_cfg and cfg_conditions != {}: | |
assert isinstance(cfg_conditions, tuple) | |
condition_tensors, null_condition_tensors = cfg_conditions | |
cond_logits = model(sequence, conditions=[], condition_tensors=condition_tensors) | |
state = self.get_streaming_state() | |
self.set_streaming_state(unconditional_state) | |
uncond_logits = model(sequence, conditions=[], condition_tensors=null_condition_tensors) | |
unconditional_state.update(self.get_streaming_state()) | |
self.set_streaming_state(state) | |
logits = uncond_logits + (cond_logits - uncond_logits) * self.cfg_coef | |
else: | |
assert isinstance(cfg_conditions, dict) | |
condition_tensors = cfg_conditions | |
if condition_tensors: | |
# Preparing for CFG, predicting both conditional and unconditional logits. | |
sequence = torch.cat([sequence, sequence], dim=0) | |
all_logits = model( | |
sequence, | |
conditions=[], condition_tensors=condition_tensors) | |
if condition_tensors: | |
cond_logits, uncond_logits = all_logits.split(B, dim=0) # [B, K, T, card] | |
logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef | |
else: | |
logits = all_logits | |
logits = logits.permute(0, 1, 3, 2) # [B, K, card, T] | |
logits = logits[..., -1] # [B x K x card] | |
# Apply softmax for sampling if temp > 0. Else, do greedy sampling to avoid zero division error. | |
if use_sampling and temp > 0.0: | |
probs = torch.softmax(logits / temp, dim=-1) | |
if top_p > 0.0: | |
next_token = utils.sample_top_p(probs, p=top_p) | |
elif top_k > 0: | |
next_token = utils.sample_top_k(probs, k=top_k) | |
else: | |
next_token = utils.multinomial(probs, num_samples=1) | |
else: | |
next_token = torch.argmax(logits, dim=-1, keepdim=True) | |
return next_token | |
def generate(self, | |
prompt: tp.Optional[torch.Tensor] = None, | |
conditions: tp.List[ConditioningAttributes] = [], | |
num_samples: tp.Optional[int] = None, | |
max_gen_len: int = 256, | |
use_sampling: bool = True, | |
temp: float = 1.0, | |
top_k: int = 250, | |
top_p: float = 0.0, | |
cfg_coef: tp.Optional[float] = None, | |
two_step_cfg: bool = False, | |
remove_prompts: bool = False, | |
check: bool = False, | |
callback: tp.Optional[tp.Callable[[int, int], None]] = None) -> torch.Tensor: | |
"""Generate tokens sampling from the model given a prompt or unconditionally. Generation can | |
be perform in a greedy fashion or using sampling with top K and top P strategies. | |
Args: | |
prompt (Optional[torch.Tensor]): Prompt tokens of shape [B, K, T]. | |
conditions_tensors (Dict[str, torch.Tensor]): Set of conditions or None. | |
num_samples (int or None): Number of samples to generate when no prompt and no conditions are given. | |
max_gen_len (int): Maximum generation length. | |
use_sampling (bool): Whether to use a sampling strategy or not. | |
temp (float): Sampling temperature. | |
top_k (int): K for "top-k" sampling. | |
top_p (float): P for "top-p" sampling. | |
remove_prompts (bool): Whether to remove prompts from generation or not. | |
Returns: | |
torch.Tensor: Generated tokens. | |
""" | |
assert not self.training, "generation shouldn't be used in training mode." | |
first_param = next(iter(self.parameters())) | |
device = first_param.device | |
# Checking all input shapes are consistents. | |
possible_num_samples = [] | |
if num_samples is not None: | |
possible_num_samples.append(num_samples) | |
elif prompt is not None: | |
possible_num_samples.append(prompt.shape[0]) | |
elif conditions: | |
possible_num_samples.append(len(conditions)) | |
else: | |
possible_num_samples.append(1) | |
assert [x == possible_num_samples[0] for x in possible_num_samples], "Inconsitent inputs shapes" | |
num_samples = possible_num_samples[0] | |
# below we create set of conditions: one conditional and one unconditional | |
# to do that we merge the regular condition together with the null condition | |
# we then do 1 forward pass instead of 2. | |
# the reason for that is two-fold: | |
# 1. it is about x2 faster than doing 2 forward passes | |
# 2. avoid the streaming API treating the 2 passes as part of different time steps | |
# We also support doing two different passes, in particular to ensure that | |
# the padding structure is exactly the same between train anf test. | |
# With a batch size of 1, this can be slower though. | |
cfg_conditions: CFGConditions | |
two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg | |
if conditions: | |
null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions) | |
if two_step_cfg: | |
cfg_conditions = ( | |
self.condition_provider(self.condition_provider.tokenize(conditions)), | |
self.condition_provider(self.condition_provider.tokenize(null_conditions)), | |
) | |
else: | |
conditions = conditions + null_conditions | |
tokenized = self.condition_provider.tokenize(conditions) | |
cfg_conditions = self.condition_provider(tokenized) | |
else: | |
cfg_conditions = {} | |
if prompt is None: | |
assert num_samples > 0 | |
prompt = torch.zeros((num_samples, self.num_codebooks, 0), dtype=torch.long, device=device) | |
B, K, T = prompt.shape | |
start_offset = T | |
assert start_offset < max_gen_len | |
pattern = self.pattern_provider.get_pattern(max_gen_len) | |
# this token is used as default value for codes that are not generated yet | |
unknown_token = -1 | |
# we generate codes up to the max_gen_len that will be mapped to the pattern sequence | |
gen_codes = torch.full((B, K, max_gen_len), unknown_token, dtype=torch.long, device=device) | |
# filling the gen_codes with the prompt if needed | |
gen_codes[..., :start_offset] = prompt | |
# create the gen_sequence with proper interleaving from the pattern: [B, K, S] | |
gen_sequence, indexes, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id) | |
# retrieve the start_offset in the sequence: | |
# it is the first sequence step that contains the `start_offset` timestep | |
start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset) | |
assert start_offset_sequence is not None | |
with self.streaming(): | |
unconditional_state = self.get_streaming_state() | |
prev_offset = 0 | |
gen_sequence_len = gen_sequence.shape[-1] # gen_sequence shape is [B, K, S] | |
for offset in range(start_offset_sequence, gen_sequence_len): | |
# get current sequence (note that the streaming API is providing the caching over previous offsets) | |
curr_sequence = gen_sequence[..., prev_offset:offset] | |
curr_mask = mask[None, ..., prev_offset:offset].expand(B, -1, -1) | |
if check: | |
# check coherence between mask and sequence | |
assert (curr_sequence == torch.where(curr_mask, curr_sequence, self.special_token_id)).all() | |
# should never happen as gen_sequence is filled progressively | |
assert not (curr_sequence == unknown_token).any() | |
# sample next token from the model, next token shape is [B, K, 1] | |
next_token = self._sample_next_token( | |
curr_sequence, cfg_conditions, unconditional_state, use_sampling, temp, top_k, top_p, | |
cfg_coef=cfg_coef) | |
# ensure the tokens that should be masked are properly set to special_token_id | |
# as the model never output special_token_id | |
valid_mask = mask[..., offset:offset+1].expand(B, -1, -1) | |
next_token[~valid_mask] = self.special_token_id | |
# ensure we don't overwrite prompt tokens, we only write over unknown tokens | |
# (then mask tokens should be left as is as well, which is correct) | |
gen_sequence[..., offset:offset+1] = torch.where( | |
gen_sequence[..., offset:offset+1] == unknown_token, | |
next_token, gen_sequence[..., offset:offset+1] | |
) | |
prev_offset = offset | |
if callback is not None: | |
callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence) | |
unconditional_state.clear() | |
# ensure sequence has been entirely filled | |
assert not (gen_sequence == unknown_token).any() | |
# ensure gen_sequence pattern and mask are matching | |
# which means the gen_sequence is valid according to the pattern | |
assert ( | |
gen_sequence == torch.where(mask[None, ...].expand(B, -1, -1), gen_sequence, self.special_token_id) | |
).all() | |
# get back the codes, trimming the prompt if needed and cutting potentially incomplete timesteps | |
out_codes, out_indexes, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token) | |
# sanity checks over the returned codes and corresponding masks | |
assert (out_codes[..., :max_gen_len] != unknown_token).all() | |
assert (out_mask[..., :max_gen_len] == 1).all() | |
out_start_offset = start_offset if remove_prompts else 0 | |
out_codes = out_codes[..., out_start_offset:max_gen_len] | |
# ensure the returned codes are all valid | |
assert (out_codes >= 0).all() and (out_codes <= self.card).all() | |
return out_codes | |