|
|
|
import soundfile |
|
from collections import defaultdict |
|
from copy import deepcopy |
|
from dataclasses import dataclass, field |
|
from itertools import chain |
|
import logging |
|
import math |
|
from pathlib import Path |
|
import random |
|
import re |
|
import typing as tp |
|
import warnings |
|
import einops |
|
from num2words import num2words |
|
import spacy |
|
from transformers import RobertaTokenizer, T5EncoderModel, T5Tokenizer |
|
import torch |
|
import torch.nn.functional as F |
|
from torch.nn.utils.rnn import pad_sequence |
|
from audiocraft.streaming import StreamingModule |
|
from audiocraft.transformer import create_sin_embedding |
|
from audiocraft.utils.audio_utils import convert_audio |
|
from audiocraft.utils.autocast import TorchAutocast |
|
from audiocraft.utils.cache import EmbeddingCache |
|
from audiocraft.utils.utils import collate, hash_trick, length_to_mask, load_clap_state_dict, warn_once

from audiocraft.environment import AudioCraftEnvironment  # needed by CLAPEmbeddingConditioner; adjust the path if the package layout differs
|
from audiocraft.transformer import StreamingTransformer, create_norm_fn |
|
from functools import partial |
|
|
|
|
|
|
from torch import nn |
|
|
|
from audiocraft.utils import utils |
|
from audiocraft.codebooks_patterns import CodebooksPatternProvider |
|
from audiocraft.activations import get_activation_fn |
|
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
TextCondition = tp.Optional[str] |
|
ConditionType = tp.Tuple[torch.Tensor, torch.Tensor] |
|
|
|
|
|
class WavCondition(tp.NamedTuple): |
|
wav: torch.Tensor |
|
length: torch.Tensor |
|
sample_rate: tp.List[int] |
|
path: tp.List[tp.Optional[str]] = [] |
|
seek_time: tp.List[tp.Optional[float]] = [] |
|
|
|
|
|
class JointEmbedCondition(tp.NamedTuple): |
|
wav: torch.Tensor |
|
text: tp.List[tp.Optional[str]] |
|
length: torch.Tensor |
|
sample_rate: tp.List[int] |
|
path: tp.List[tp.Optional[str]] = [] |
|
seek_time: tp.List[tp.Optional[float]] = [] |
|
|
|
|
|
@dataclass |
|
class ConditioningAttributes: |
|
text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict) |
|
wav: tp.Dict[str, WavCondition] = field(default_factory=dict) |
|
joint_embed: tp.Dict[str, JointEmbedCondition] = field(default_factory=dict) |
|
|
|
def __getitem__(self, item): |
|
return getattr(self, item) |
|
|
|
@property |
|
def text_attributes(self): |
|
return self.text.keys() |
|
|
|
@property |
|
def wav_attributes(self): |
|
return self.wav.keys() |
|
|
|
@property |
|
def joint_embed_attributes(self): |
|
return self.joint_embed.keys() |
|
|
|
@property |
|
def attributes(self): |
|
return { |
|
"text": self.text_attributes, |
|
"wav": self.wav_attributes, |
|
"joint_embed": self.joint_embed_attributes, |
|
} |
|
|
|
def to_flat_dict(self): |
|
return { |
|
**{f"text.{k}": v for k, v in self.text.items()}, |
|
**{f"wav.{k}": v for k, v in self.wav.items()}, |
|
**{f"joint_embed.{k}": v for k, v in self.joint_embed.items()} |
|
} |
|
|
|
@classmethod |
|
def from_flat_dict(cls, x): |
|
out = cls() |
|
for k, v in x.items(): |
|
kind, att = k.split(".") |
|
out[kind][att] = v |
|
return out |
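
# Example (illustrative sketch, not part of the original code): round-tripping attributes

# through the flat dict representation.

#   attrs = ConditioningAttributes(text={"description": "a rock song"})

#   flat = attrs.to_flat_dict()                        # {"text.description": "a rock song"}

#   ConditioningAttributes.from_flat_dict(flat).text["description"] == "a rock song"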
|
|
|
|
|
|
|
|
|
|
|
def nullify_condition(condition: ConditionType, dim: int = 1): |
|
"""Transform an input condition to a null condition. |
|
The way it is done by converting it to a single zero vector similarly |
|
to how it is done inside WhiteSpaceTokenizer and NoopTokenizer. |
|
|
|
Args: |
|
condition (ConditionType): A tuple of condition and mask (tuple[torch.Tensor, torch.Tensor]) |
|
dim (int): The dimension that will be truncated (should be the time dimension) |
|
WARNING!: dim should not be the batch dimension! |
|
Returns: |
|
ConditionType: A tuple of null condition and mask |
|
""" |
|
assert dim != 0, "dim cannot be the batch dimension!" |
|
assert isinstance(condition, tuple) and \ |
|
isinstance(condition[0], torch.Tensor) and \ |
|
isinstance(condition[1], torch.Tensor), "'nullify_condition' got an unexpected input type!" |
|
cond, mask = condition |
|
B = cond.shape[0] |
|
last_dim = cond.dim() - 1 |
|
out = cond.transpose(dim, last_dim) |
|
out = 0. * out[..., :1] |
|
out = out.transpose(dim, last_dim) |
|
mask = torch.zeros((B, 1), device=out.device).int() |
|
assert cond.dim() == out.dim() |
|
return out, mask |
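
# Example (illustrative sketch, not part of the original code): nullifying a [B, T, D] condition

# keeps the batch and feature dims and truncates the time dimension to a single zeroed step.

#   cond, mask = torch.randn(2, 10, 8), torch.ones(2, 10)

#   null_cond, null_mask = nullify_condition((cond, mask), dim=1)

#   null_cond has shape [2, 1, 8] and is all zeros; null_mask is an all-zero [2, 1] mask.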
|
|
|
|
|
def nullify_wav(cond: WavCondition) -> WavCondition: |
|
"""Transform a WavCondition to a nullified WavCondition. |
|
It replaces the wav by a null tensor, forces its length to 0, and replaces metadata by dummy attributes. |
|
|
|
Args: |
|
cond (WavCondition): Wav condition with wav, tensor of shape [B, T]. |
|
Returns: |
|
WavCondition: Nullified wav condition. |
|
""" |
|
null_wav, _ = nullify_condition((cond.wav, torch.zeros_like(cond.wav)), dim=cond.wav.dim() - 1) |
|
return WavCondition( |
|
wav=null_wav, |
|
length=torch.tensor([0] * cond.wav.shape[0], device=cond.wav.device), |
|
sample_rate=cond.sample_rate, |
|
path=[None] * cond.wav.shape[0], |
|
seek_time=[None] * cond.wav.shape[0], |
|
) |
|
|
|
|
|
def nullify_joint_embed(embed: JointEmbedCondition) -> JointEmbedCondition: |
|
"""Nullify the joint embedding condition by replacing it by a null tensor, forcing its length to 0, |
|
and replacing metadata by dummy attributes. |
|
|
|
Args: |
|
        embed (JointEmbedCondition): Joint embedding condition with wav and text, wav tensor of shape [B, C, T].
|
""" |
|
null_wav, _ = nullify_condition((embed.wav, torch.zeros_like(embed.wav)), dim=embed.wav.dim() - 1) |
|
return JointEmbedCondition( |
|
wav=null_wav, text=[None] * len(embed.text), |
|
length=torch.LongTensor([0]).to(embed.wav.device), |
|
sample_rate=embed.sample_rate, |
|
path=[None] * embed.wav.shape[0], |
|
seek_time=[0] * embed.wav.shape[0], |
|
) |
|
|
|
|
|
class Tokenizer: |
|
"""Base tokenizer implementation |
|
    (in case we want to introduce more advanced tokenizers in the future).
|
""" |
|
def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]: |
|
raise NotImplementedError() |
|
|
|
|
|
class WhiteSpaceTokenizer(Tokenizer): |
|
"""This tokenizer should be used for natural language descriptions. |
|
For example: |
|
["he didn't, know he's going home.", 'shorter sentence'] => |
|
[[78, 62, 31, 4, 78, 25, 19, 34], |
|
[59, 77, 0, 0, 0, 0, 0, 0]] |
|
""" |
|
PUNCTUATION = "?:!.,;" |
|
|
|
def __init__(self, n_bins: int, pad_idx: int = 0, language: str = "en_core_web_sm", |
|
lemma: bool = True, stopwords: bool = True) -> None: |
|
self.n_bins = n_bins |
|
self.pad_idx = pad_idx |
|
self.lemma = lemma |
|
self.stopwords = stopwords |
|
try: |
|
self.nlp = spacy.load(language) |
|
except IOError: |
|
spacy.cli.download(language) |
|
self.nlp = spacy.load(language) |
|
|
|
@tp.no_type_check |
|
def __call__(self, texts: tp.List[tp.Optional[str]], |
|
return_text: bool = False) -> tp.Tuple[torch.Tensor, torch.Tensor]: |
|
"""Take a list of strings and convert them to a tensor of indices. |
|
|
|
Args: |
|
texts (list[str]): List of strings. |
|
return_text (bool, optional): Whether to return text as additional tuple item. Defaults to False. |
|
Returns: |
|
tuple[torch.Tensor, torch.Tensor]: |
|
- Indices of words in the LUT. |
|
- And a mask indicating where the padding tokens are |
|
""" |
|
output, lengths = [], [] |
|
texts = deepcopy(texts) |
|
for i, text in enumerate(texts): |
|
|
|
if text is None: |
|
output.append(torch.Tensor([self.pad_idx])) |
|
lengths.append(0) |
|
continue |
|
|
|
|
|
text = re.sub(r"(\d+)", lambda x: num2words(int(x.group(0))), text) |
|
|
|
text = self.nlp(text) |
|
|
|
if self.stopwords: |
|
text = [w for w in text if not w.is_stop] |
|
|
|
text = [w for w in text if w.text not in self.PUNCTUATION] |
|
|
|
text = [getattr(t, "lemma_" if self.lemma else "text") for t in text] |
|
|
|
texts[i] = " ".join(text) |
|
lengths.append(len(text)) |
|
|
|
tokens = torch.Tensor([hash_trick(w, self.n_bins) for w in text]) |
|
output.append(tokens) |
|
|
|
mask = length_to_mask(torch.IntTensor(lengths)).int() |
|
padded_output = pad_sequence(output, padding_value=self.pad_idx).int().t() |
|
if return_text: |
|
return padded_output, mask, texts |
|
return padded_output, mask |
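
# Example (illustrative sketch, not part of the original code): the whitespace tokenizer hashes

# each remaining word into one of `n_bins` buckets and pads shorter sentences with `pad_idx`.

#   tok = WhiteSpaceTokenizer(n_bins=128)

#   tokens, mask = tok(["a rock song with a guitar solo", None])

#   tokens is a [2, T] tensor of hashed word indices; mask is [2, T] with zeros on padding.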
|
|
|
|
|
class NoopTokenizer(Tokenizer): |
|
"""This tokenizer should be used for global conditioners such as: artist, genre, key, etc. |
|
The difference between this and WhiteSpaceTokenizer is that NoopTokenizer does not split |
|
strings, so "Jeff Buckley" will get it's own index. Whereas WhiteSpaceTokenizer will |
|
split it to ["Jeff", "Buckley"] and return an index per word. |
|
|
|
For example: |
|
["Queen", "ABBA", "Jeff Buckley"] => [43, 55, 101] |
|
["Metal", "Rock", "Classical"] => [0, 223, 51] |
|
""" |
|
def __init__(self, n_bins: int, pad_idx: int = 0): |
|
self.n_bins = n_bins |
|
self.pad_idx = pad_idx |
|
|
|
def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]: |
|
output, lengths = [], [] |
|
for text in texts: |
|
|
|
if text is None: |
|
output.append(self.pad_idx) |
|
lengths.append(0) |
|
else: |
|
output.append(hash_trick(text, self.n_bins)) |
|
lengths.append(1) |
|
|
|
tokens = torch.LongTensor(output).unsqueeze(1) |
|
mask = length_to_mask(torch.IntTensor(lengths)).int() |
|
return tokens, mask |
|
|
|
|
|
class BaseConditioner(nn.Module): |
|
"""Base model for all conditioner modules. |
|
We allow the output dim to be different than the hidden dim for two reasons: |
|
1) keep our LUTs small when the vocab is large; |
|
2) make all condition dims consistent. |
|
|
|
Args: |
|
dim (int): Hidden dim of the model. |
|
output_dim (int): Output dim of the conditioner. |
|
""" |
|
def __init__(self, dim: int, output_dim: int): |
|
super().__init__() |
|
self.dim = dim |
|
self.output_dim = output_dim |
|
self.output_proj = nn.Linear(dim, output_dim) |
|
|
|
def tokenize(self, *args, **kwargs) -> tp.Any: |
|
"""Should be any part of the processing that will lead to a synchronization |
|
point, e.g. BPE tokenization with transfer to the GPU. |
|
|
|
        The returned value will be saved and returned later when calling forward().
|
""" |
|
raise NotImplementedError() |
|
|
|
def forward(self, inputs: tp.Any) -> ConditionType: |
|
"""Gets input that should be used as conditioning (e.g, genre, description or a waveform). |
|
Outputs a ConditionType, after the input data was embedded as a dense vector. |
|
|
|
Returns: |
|
ConditionType: |
|
- A tensor of size [B, T, D] where B is the batch size, T is the length of the |
|
output embedding and D is the dimension of the embedding. |
|
                - And a mask indicating where the padding tokens are.
|
""" |
|
raise NotImplementedError() |
|
|
|
|
|
class TextConditioner(BaseConditioner): |
|
... |
|
|
|
|
|
class LUTConditioner(TextConditioner): |
|
"""Lookup table TextConditioner. |
|
|
|
Args: |
|
n_bins (int): Number of bins. |
|
dim (int): Hidden dim of the model (text-encoder/LUT). |
|
output_dim (int): Output dim of the conditioner. |
|
tokenizer (str): Name of the tokenizer. |
|
pad_idx (int, optional): Index for padding token. Defaults to 0. |
|
""" |
|
def __init__(self, n_bins: int, dim: int, output_dim: int, tokenizer: str, pad_idx: int = 0): |
|
super().__init__(dim, output_dim) |
|
self.embed = nn.Embedding(n_bins, dim) |
|
self.tokenizer: Tokenizer |
|
if tokenizer == 'whitespace': |
|
self.tokenizer = WhiteSpaceTokenizer(n_bins, pad_idx=pad_idx) |
|
elif tokenizer == 'noop': |
|
self.tokenizer = NoopTokenizer(n_bins, pad_idx=pad_idx) |
|
else: |
|
raise ValueError(f"unrecognized tokenizer `{tokenizer}`.") |
|
|
|
def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]: |
|
device = self.embed.weight.device |
|
tokens, mask = self.tokenizer(x) |
|
tokens, mask = tokens.to(device), mask.to(device) |
|
return tokens, mask |
|
|
|
def forward(self, inputs: tp.Tuple[torch.Tensor, torch.Tensor]) -> ConditionType: |
|
tokens, mask = inputs |
|
embeds = self.embed(tokens) |
|
embeds = self.output_proj(embeds) |
|
embeds = (embeds * mask.unsqueeze(-1)) |
|
return embeds, mask |
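
# Example (illustrative sketch, not part of the original code): with the 'noop' tokenizer every

# string maps to a single hashed index, so the conditioner yields one embedding step per item.

#   cond = LUTConditioner(n_bins=1024, dim=64, output_dim=128, tokenizer='noop')

#   embeds, mask = cond(cond.tokenize(["rock", None]))

#   embeds has shape [2, 1, 128]; mask is [2, 1] with the second row zeroed (None entry).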
|
|
|
|
|
class T5Conditioner(TextConditioner): |
|
"""T5-based TextConditioner. |
|
|
|
Args: |
|
name (str): Name of the T5 model. |
|
output_dim (int): Output dim of the conditioner. |
|
finetune (bool): Whether to fine-tune T5 at train time. |
|
device (str): Device for T5 Conditioner. |
|
autocast_dtype (tp.Optional[str], optional): Autocast dtype. |
|
word_dropout (float, optional): Word dropout probability. |
|
normalize_text (bool, optional): Whether to apply text normalization. |
|
""" |
|
MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b", |
|
"google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large", |
|
"google/flan-t5-xl", "google/flan-t5-xxl"] |
|
MODELS_DIMS = { |
|
"t5-small": 512, |
|
"t5-base": 768, |
|
"t5-large": 1024, |
|
"t5-3b": 1024, |
|
"t5-11b": 1024, |
|
"google/flan-t5-small": 512, |
|
"google/flan-t5-base": 768, |
|
"google/flan-t5-large": 1024, |
|
"google/flan-t5-3b": 1024, |
|
"google/flan-t5-11b": 1024, |
|
} |
|
|
|
def __init__(self, name: str, output_dim: int, finetune: bool, device: str, |
|
autocast_dtype: tp.Optional[str] = 'float32', word_dropout: float = 0., |
|
normalize_text: bool = False): |
|
        assert name in self.MODELS, f"Unrecognized t5 model name (should be in {self.MODELS})"
|
super().__init__(self.MODELS_DIMS[name], output_dim) |
|
self.device = device |
|
self.name = name |
|
self.finetune = finetune |
|
self.word_dropout = word_dropout |
|
if autocast_dtype is None or self.device == 'cpu': |
|
self.autocast = TorchAutocast(enabled=False) |
|
if self.device != 'cpu': |
|
logger.warning("T5 has no autocast, this might lead to NaN") |
|
else: |
|
dtype = getattr(torch, autocast_dtype) |
|
assert isinstance(dtype, torch.dtype) |
|
logger.info(f"T5 will be evaluated with autocast as {autocast_dtype}") |
|
self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype) |
|
|
|
|
|
previous_level = logging.root.manager.disable |
|
logging.disable(logging.ERROR) |
|
with warnings.catch_warnings(): |
|
warnings.simplefilter("ignore") |
|
try: |
|
self.t5_tokenizer = T5Tokenizer.from_pretrained(name) |
|
t5 = T5EncoderModel.from_pretrained(name).train(mode=finetune) |
|
finally: |
|
logging.disable(previous_level) |
|
if finetune: |
|
self.t5 = t5 |
|
else: |
|
|
|
|
|
self.__dict__['t5'] = t5.to(device) |
|
|
|
self.normalize_text = normalize_text |
|
if normalize_text: |
|
self.text_normalizer = WhiteSpaceTokenizer(1, lemma=True, stopwords=True) |
|
|
|
def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch.Tensor]: |
|
|
|
entries: tp.List[str] = [xi if xi is not None else "" for xi in x] |
|
if self.normalize_text: |
|
_, _, entries = self.text_normalizer(entries, return_text=True) |
|
if self.word_dropout > 0. and self.training: |
|
new_entries = [] |
|
for entry in entries: |
|
words = [word for word in entry.split(" ") if random.random() >= self.word_dropout] |
|
new_entries.append(" ".join(words)) |
|
entries = new_entries |
|
|
|
empty_idx = torch.LongTensor([i for i, xi in enumerate(entries) if xi == ""]) |
|
|
|
inputs = self.t5_tokenizer(entries, return_tensors='pt', padding=True).to(self.device) |
|
mask = inputs['attention_mask'] |
|
mask[empty_idx, :] = 0 |
|
return inputs |
|
|
|
def forward(self, inputs: tp.Dict[str, torch.Tensor]) -> ConditionType: |
|
mask = inputs['attention_mask'] |
|
with torch.set_grad_enabled(self.finetune), self.autocast: |
|
embeds = self.t5(**inputs).last_hidden_state |
|
embeds = self.output_proj(embeds.to(self.output_proj.weight)) |
|
embeds = (embeds * mask.unsqueeze(-1)) |
|
return embeds, mask |
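
# Example (illustrative sketch, not part of the original code): the T5 conditioner follows the

# two-step BaseConditioner protocol, first tokenizing and then embedding.

#   t5_cond = T5Conditioner('t5-small', output_dim=128, finetune=False, device='cpu')

#   inputs = t5_cond.tokenize(["a calm piano piece", ""])   # HF batch with empty entries masked

#   embeds, mask = t5_cond(inputs)                          # embeds: [2, T, 128]; mask: [2, T]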
|
|
|
|
|
class WaveformConditioner(BaseConditioner): |
|
"""Base class for all conditioners that take a waveform as input. |
|
Classes that inherit must implement `_get_wav_embedding` that outputs |
|
a continuous tensor, and `_downsampling_factor` that returns the down-sampling |
|
factor of the embedding model. |
|
|
|
Args: |
|
dim (int): The internal representation dimension. |
|
output_dim (int): Output dimension. |
|
device (tp.Union[torch.device, str]): Device. |
|
""" |
|
def __init__(self, dim: int, output_dim: int, device: tp.Union[torch.device, str]): |
|
super().__init__(dim, output_dim) |
|
self.device = device |
|
|
|
self._use_masking = True |
|
|
|
def tokenize(self, x: WavCondition) -> WavCondition: |
|
wav, length, sample_rate, path, seek_time = x |
|
assert length is not None |
|
return WavCondition(wav.to(self.device), length.to(self.device), sample_rate, path, seek_time) |
|
|
|
def _get_wav_embedding(self, x: WavCondition) -> torch.Tensor: |
|
"""Gets as input a WavCondition and returns a dense embedding.""" |
|
raise NotImplementedError() |
|
|
|
def _downsampling_factor(self): |
|
"""Returns the downsampling factor of the embedding model.""" |
|
raise NotImplementedError() |
|
|
|
def forward(self, x: WavCondition) -> ConditionType: |
|
"""Extract condition embedding and mask from a waveform and its metadata. |
|
Args: |
|
x (WavCondition): Waveform condition containing raw waveform and metadata. |
|
Returns: |
|
ConditionType: a dense vector representing the conditioning along with its mask |
|
""" |
|
wav, lengths, *_ = x |
|
with torch.no_grad(): |
|
embeds = self._get_wav_embedding(x) |
|
embeds = embeds.to(self.output_proj.weight) |
|
embeds = self.output_proj(embeds) |
|
|
|
if lengths is not None and self._use_masking: |
|
lengths = lengths / self._downsampling_factor() |
|
mask = length_to_mask(lengths, max_len=embeds.shape[1]).int() |
|
else: |
|
mask = torch.ones_like(embeds[..., 0]) |
|
embeds = (embeds * mask.unsqueeze(-1)) |
|
return embeds, mask |
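
# Note (added for clarity): subclasses only implement `_get_wav_embedding` and `_downsampling_factor`;

# the shared forward above projects the embedding to `output_dim` and, when lengths are available,

# masks the frames beyond `length / downsampling_factor`.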
|
|
|
|
|
|
|
|
|
|
|
class JointEmbeddingConditioner(BaseConditioner): |
|
"""Joint embedding conditioning supporting both audio or text conditioning. |
|
|
|
Args: |
|
dim (int): Dimension. |
|
output_dim (int): Output dimension. |
|
device (str): Device. |
|
attribute (str): Attribute used by the conditioner. |
|
autocast_dtype (str): Autocast for the conditioner. |
|
quantize (bool): Whether to quantize the CLAP embedding. |
|
n_q (int): Number of residual quantizers (used if quantize is true). |
|
bins (int): Quantizers' codebooks size (used if quantize is true). |
|
kwargs: Additional parameters for residual vector quantizer. |
|
""" |
|
def __init__(self, dim: int, output_dim: int, device: str, attribute: str, |
|
autocast_dtype: tp.Optional[str] = 'float32', quantize: bool = True, |
|
n_q: int = 12, bins: int = 1024, **kwargs): |
|
super().__init__(dim=dim, output_dim=output_dim) |
|
self.device = device |
|
self.attribute = attribute |
|
if autocast_dtype is None or device == 'cpu': |
|
self.autocast = TorchAutocast(enabled=False) |
|
logger.warning("JointEmbeddingConditioner has no autocast, this might lead to NaN.") |
|
else: |
|
dtype = getattr(torch, autocast_dtype) |
|
assert isinstance(dtype, torch.dtype) |
|
logger.info(f"JointEmbeddingConditioner will be evaluated with autocast as {autocast_dtype}.") |
|
self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype) |
|
|
|
        self.quantizer = None

        if quantize:

            logger.warning("Quantization was requested but no quantizer is instantiated in this implementation.")
|
|
|
|
|
def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]: |
|
"""Get joint embedding in latent space from the inputs. |
|
|
|
Returns: |
|
tuple[torch.Tensor, torch.Tensor]: Tensor for the latent embedding |
|
and corresponding empty indexes. |
|
""" |
|
raise NotImplementedError() |
|
|
|
def forward(self, x: JointEmbedCondition) -> ConditionType: |
|
with self.autocast: |
|
embed, empty_idx = self._get_embed(x) |
|
if self.quantizer is not None: |
|
embed = embed.view(-1, self.dim, 1) |
|
q_res = self.quantizer(embed, frame_rate=1) |
|
out_embed = q_res.x.view(-1, self.dim) |
|
else: |
|
out_embed = embed |
|
out_embed = self.output_proj(out_embed).view(-1, 1, self.output_dim) |
|
mask = torch.ones(*out_embed.shape[:2], device=out_embed.device) |
|
mask[empty_idx, :] = 0 |
|
out_embed = (out_embed * mask.unsqueeze(-1)) |
|
return out_embed, mask |
|
|
|
def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition: |
|
return x |
|
|
|
|
|
class CLAPEmbeddingConditioner(JointEmbeddingConditioner): |
|
"""Joint Embedding conditioner based on pre-trained CLAP model. |
|
|
|
This CLAP-based conditioner supports a caching mechanism |
|
over the computed embeddings for faster training. |
|
|
|
Args: |
|
dim (int): Dimension. |
|
output_dim (int): Output dimension. |
|
device (str): Device. |
|
attribute (str): Attribute used by the conditioner. |
|
quantize (bool): Whether to quantize the CLAP embedding. |
|
n_q (int): Number of residual quantizers (used if quantize is true). |
|
bins (int): Quantizers' codebooks size (used if quantize is true). |
|
checkpoint (str): Path to CLAP checkpoint. |
|
model_arch (str): CLAP model architecture. |
|
enable_fusion (bool): Enable fusion for CLAP model. |
|
sample_rate (int): Sample rate used by CLAP model. |
|
max_audio_length (float): Maximum audio length for CLAP model. |
|
audio_stride (float): Stride to use for getting a CLAP embedding on the full sequence. |
|
normalize (bool): Whether to normalize the CLAP embedding. |
|
text_p (float): Probability of using text representation instead of audio at train time. |
|
batch_size (Optional[int]): Batch size for CLAP embedding computation. |
|
autocast_dtype (str): Autocast for the conditioner. |
|
cache_path (Optional[str]): Path for pre-computed embeddings caching. |
|
kwargs: Additional parameters for residual vector quantizer. |
|
""" |
|
def __init__(self, dim: int, output_dim: int, device: str, attribute: str, |
|
quantize: bool, n_q: int, bins: int, checkpoint: tp.Union[str, Path], model_arch: str, |
|
enable_fusion: bool, sample_rate: int, max_audio_length: int, audio_stride: int, |
|
                 normalize: bool, text_p: float, batch_size: tp.Optional[int] = None,
|
autocast_dtype: tp.Optional[str] = 'float32', cache_path: tp.Optional[str] = None, **kwargs): |
|
try: |
|
import laion_clap |
|
except ImportError: |
|
raise ImportError("Please install CLAP to use the CLAPEmbeddingConditioner: 'pip install laion_clap'") |
|
warnings.warn("Sample rate for CLAP conditioner was fixed in version v1.1.0, (from 44.1 to 48 kHz). " |
|
"Please retrain all models.") |
|
checkpoint = AudioCraftEnvironment.resolve_reference_path(checkpoint) |
|
clap_tokenize = RobertaTokenizer.from_pretrained('roberta-base') |
|
clap_model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=model_arch) |
|
load_clap_state_dict(clap_model, checkpoint) |
|
clap_model.eval() |
|
clap_model.to(device) |
|
super().__init__(dim=dim, output_dim=output_dim, device=device, attribute=attribute, |
|
autocast_dtype=autocast_dtype, quantize=quantize, n_q=n_q, bins=bins, |
|
**kwargs) |
|
self.checkpoint = checkpoint |
|
self.enable_fusion = enable_fusion |
|
self.model_arch = model_arch |
|
self.clap: laion_clap.CLAP_Module |
|
self.clap_tokenize: RobertaTokenizer |
|
self.clap_sample_rate = sample_rate |
|
self.clap_max_frames = int(self.clap_sample_rate * max_audio_length) |
|
self.clap_stride = int(self.clap_sample_rate * audio_stride) |
|
self.batch_size = batch_size or 1 |
|
self.normalize = normalize |
|
self.text_p = text_p |
|
self.__dict__['clap_tokenize'] = clap_tokenize |
|
self.__dict__['clap'] = clap_model |
|
self.wav_cache, self.text_cache = None, None |
|
if cache_path is not None: |
|
self.wav_cache = EmbeddingCache(Path(cache_path) / 'wav', self.device, |
|
compute_embed_fn=self._get_wav_embedding_for_cache, |
|
extract_embed_fn=self._extract_wav_embedding_chunk) |
|
self.text_cache = EmbeddingCache(Path(cache_path) / 'text', self.device, |
|
compute_embed_fn=self._get_text_embedding_for_cache) |
|
|
|
def _tokenizer(self, texts: tp.Union[str, tp.List[str]]) -> dict: |
|
|
|
return self.clap_tokenize(texts, padding="max_length", truncation=True, max_length=77, return_tensors="pt") |
|
|
|
def _compute_text_embedding(self, text: tp.List[str]) -> torch.Tensor: |
|
"""Compute text embedding from CLAP model on a given a batch of text. |
|
|
|
Args: |
|
text (list[str]): List of text for the batch, with B items. |
|
Returns: |
|
torch.Tensor: CLAP embedding derived from text, of shape [B, 1, D], with D the CLAP embedding dimension. |
|
""" |
|
with torch.no_grad(): |
|
embed = self.clap.get_text_embedding(text, tokenizer=self._tokenizer, use_tensor=True) |
|
return embed.view(embed.size(0), 1, embed.size(-1)) |
|
|
|
def _get_text_embedding_for_cache(self, path: tp.Union[Path, str], |
|
x: JointEmbedCondition, idx: int) -> torch.Tensor: |
|
"""Get text embedding function for the cache.""" |
|
text = x.text[idx] |
|
text = text if text is not None else "" |
|
return self._compute_text_embedding([text])[0] |
|
|
|
def _preprocess_wav(self, wav: torch.Tensor, length: torch.Tensor, sample_rates: tp.List[int]) -> torch.Tensor: |
|
"""Preprocess wav to expected format by CLAP model. |
|
|
|
Args: |
|
wav (torch.Tensor): Audio wav, of shape [B, C, T]. |
|
length (torch.Tensor): Actual length of the audio for each item in the batch, of shape [B]. |
|
sample_rates (list[int]): Sample rates for each sample in the batch |
|
Returns: |
|
torch.Tensor: Audio wav of shape [B, T]. |
|
""" |
|
assert wav.dim() == 3, "Expecting wav to be [B, C, T]" |
|
if sample_rates is not None: |
|
_wav = [] |
|
for i, audio in enumerate(wav): |
|
sr = sample_rates[i] |
|
audio = convert_audio(audio, from_rate=sr, to_rate=self.clap_sample_rate, to_channels=1) |
|
_wav.append(audio) |
|
wav = torch.stack(_wav, dim=0) |
|
wav = wav.mean(dim=1) |
|
return wav |
|
|
|
def _compute_wav_embedding(self, wav: torch.Tensor, length: torch.Tensor, |
|
sample_rates: tp.List[int], reduce_mean: bool = False) -> torch.Tensor: |
|
"""Compute audio wave embedding from CLAP model. |
|
|
|
Since CLAP operates on a fixed sequence length audio inputs and we need to process longer audio sequences, |
|
we calculate the wav embeddings on `clap_max_frames` windows with `clap_stride`-second stride and |
|
average the resulting embeddings. |
|
|
|
Args: |
|
wav (torch.Tensor): Audio wav, of shape [B, C, T]. |
|
length (torch.Tensor): Actual length of the audio for each item in the batch, of shape [B]. |
|
sample_rates (list[int]): Sample rates for each sample in the batch. |
|
reduce_mean (bool): Whether to get the average tensor. |
|
Returns: |
|
torch.Tensor: Audio embedding of shape [B, F, D], F being the number of chunks, D the dimension. |
|
""" |
|
with torch.no_grad(): |
|
wav = self._preprocess_wav(wav, length, sample_rates) |
|
B, T = wav.shape |
|
if T >= self.clap_max_frames: |
|
wav = wav.unfold(-1, self.clap_max_frames, self.clap_stride) |
|
else: |
|
wav = wav.view(-1, 1, T) |
|
wav = einops.rearrange(wav, 'b f t -> (b f) t') |
|
embed_list = [] |
|
for i in range(0, wav.size(0), self.batch_size): |
|
_wav = wav[i:i+self.batch_size, ...] |
|
_embed = self.clap.get_audio_embedding_from_data(_wav, use_tensor=True) |
|
embed_list.append(_embed) |
|
embed = torch.cat(embed_list, dim=0) |
|
embed = einops.rearrange(embed, '(b f) d -> b f d', b=B) |
|
if reduce_mean: |
|
embed = embed.mean(dim=1, keepdim=True) |
|
return embed |
|
|
|
def _get_wav_embedding_for_cache(self, path: tp.Union[str, Path], |
|
x: JointEmbedCondition, idx: int) -> torch.Tensor: |
|
"""Compute audio wave embedding for the cache. |
|
The embedding is computed on a given audio read from file. |
|
|
|
Args: |
|
path (str or Path): Path to the full audio file. |
|
Returns: |
|
torch.Tensor: Single-item tensor of shape [F, D], F being the number of chunks, D the dimension. |
|
""" |
|
        wav_np, sr = soundfile.read(path, dtype='float32', always_2d=True)  # numpy array of shape [T, C]

        wav = torch.from_numpy(wav_np).t().unsqueeze(0).to(self.device)  # -> [1, C, T] as expected downstream
|
wav_len = torch.LongTensor([wav.shape[-1]]).to(self.device) |
|
embed = self._compute_wav_embedding(wav, wav_len, [sr], reduce_mean=False) |
|
return embed.squeeze(0) |
|
|
|
def _extract_wav_embedding_chunk(self, full_embed: torch.Tensor, x: JointEmbedCondition, idx: int) -> torch.Tensor: |
|
"""Extract the chunk of embedding matching the seek_time and length from the full CLAP audio embedding. |
|
|
|
Args: |
|
full_embed (torch.Tensor): CLAP embedding computed on the full wave, of shape [F, D]. |
|
x (JointEmbedCondition): Joint embedding condition for the full batch. |
|
idx (int): Index considered for the given embedding to extract. |
|
Returns: |
|
torch.Tensor: Wav embedding averaged on sliding window, of shape [1, D]. |
|
""" |
|
sample_rate = x.sample_rate[idx] |
|
seek_time = x.seek_time[idx] |
|
seek_time = 0. if seek_time is None else seek_time |
|
clap_stride = int(self.clap_stride / self.clap_sample_rate) * sample_rate |
|
end_seek_time = seek_time + self.clap_max_frames / self.clap_sample_rate |
|
start_offset = int(seek_time * sample_rate // clap_stride) |
|
end_offset = int(end_seek_time * sample_rate // clap_stride) |
|
wav_embed = full_embed[start_offset:end_offset, ...] |
|
wav_embed = wav_embed.mean(dim=0, keepdim=True) |
|
return wav_embed.to(self.device) |
|
|
|
def _get_text_embedding(self, x: JointEmbedCondition) -> torch.Tensor: |
|
"""Get CLAP embedding from a batch of text descriptions.""" |
|
no_nullified_cond = x.wav.shape[-1] > 1 |
|
if self.text_cache is not None and no_nullified_cond: |
|
assert all(p is not None for p in x.path), "Cache requires all JointEmbedCondition paths to be provided" |
|
paths = [Path(p) for p in x.path if p is not None] |
|
embed = self.text_cache.get_embed_from_cache(paths, x) |
|
else: |
|
text = [xi if xi is not None else "" for xi in x.text] |
|
embed = self._compute_text_embedding(text) |
|
if self.normalize: |
|
embed = torch.nn.functional.normalize(embed, p=2.0, dim=-1) |
|
return embed |
|
|
|
def _get_wav_embedding(self, x: JointEmbedCondition) -> torch.Tensor: |
|
"""Get CLAP embedding from a batch of audio tensors (and corresponding sample rates).""" |
|
no_undefined_paths = all(p is not None for p in x.path) |
|
no_nullified_cond = x.wav.shape[-1] > 1 |
|
if self.wav_cache is not None and no_undefined_paths and no_nullified_cond: |
|
paths = [Path(p) for p in x.path if p is not None] |
|
embed = self.wav_cache.get_embed_from_cache(paths, x) |
|
else: |
|
embed = self._compute_wav_embedding(x.wav, x.length, x.sample_rate, reduce_mean=True) |
|
if self.normalize: |
|
embed = torch.nn.functional.normalize(embed, p=2.0, dim=-1) |
|
return embed |
|
|
|
def tokenize(self, x: JointEmbedCondition) -> JointEmbedCondition: |
|
|
|
no_undefined_paths = all(p is not None for p in x.path) |
|
if self.wav_cache is not None and no_undefined_paths: |
|
assert all([p is not None for p in x.path]), "Cache requires all JointEmbedCondition paths to be provided" |
|
paths = [Path(p) for p in x.path if p is not None] |
|
self.wav_cache.populate_embed_cache(paths, x) |
|
if self.text_cache is not None and no_undefined_paths: |
|
assert all([p is not None for p in x.path]), "Cache requires all JointEmbedCondition paths to be provided" |
|
paths = [Path(p) for p in x.path if p is not None] |
|
self.text_cache.populate_embed_cache(paths, x) |
|
return x |
|
|
|
def _get_embed(self, x: JointEmbedCondition) -> tp.Tuple[torch.Tensor, torch.Tensor]: |
|
"""Extract shared latent representation from either the wav or the text using CLAP.""" |
|
|
|
use_text_embed = random.random() < self.text_p |
|
if self.training and not use_text_embed: |
|
embed = self._get_wav_embedding(x) |
|
empty_idx = torch.LongTensor([]) |
|
else: |
|
embed = self._get_text_embedding(x) |
|
empty_idx = torch.LongTensor([i for i, xi in enumerate(x.text) if xi is None or xi == ""]) |
|
return embed, empty_idx |
|
|
|
|
|
def dropout_condition(sample: ConditioningAttributes, condition_type: str, condition: str) -> ConditioningAttributes: |
|
"""Utility function for nullifying an attribute inside an ConditioningAttributes object. |
|
If the condition is of type "wav", then nullify it using `nullify_condition` function. |
|
If the condition is of any other type, set its value to None. |
|
Works in-place. |
|
""" |
|
if condition_type not in ['text', 'wav', 'joint_embed']: |
|
raise ValueError( |
|
"dropout_condition got an unexpected condition type!" |
|
f" expected 'text', 'wav' or 'joint_embed' but got '{condition_type}'" |
|
) |
|
|
|
if condition not in getattr(sample, condition_type): |
|
raise ValueError( |
|
"dropout_condition received an unexpected condition!" |
|
f" expected wav={sample.wav.keys()} and text={sample.text.keys()}" |
|
f" but got '{condition}' of type '{condition_type}'!" |
|
) |
|
|
|
if condition_type == 'wav': |
|
wav_cond = sample.wav[condition] |
|
sample.wav[condition] = nullify_wav(wav_cond) |
|
elif condition_type == 'joint_embed': |
|
embed = sample.joint_embed[condition] |
|
sample.joint_embed[condition] = nullify_joint_embed(embed) |
|
else: |
|
sample.text[condition] = None |
|
|
|
return sample |
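
# Example (illustrative sketch, not part of the original code): dropping a text attribute in place.

#   attrs = ConditioningAttributes(text={"genre": "rock"})

#   dropout_condition(attrs, "text", "genre")

#   attrs.text["genre"] is now None; "wav"/"joint_embed" attributes would be nullified instead.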
|
|
|
|
|
class DropoutModule(nn.Module): |
|
"""Base module for all dropout modules.""" |
|
def __init__(self, seed: int = 1234): |
|
super().__init__() |
|
self.rng = torch.Generator() |
|
self.rng.manual_seed(seed) |
|
|
|
|
|
class AttributeDropout(DropoutModule): |
|
"""Dropout with a given probability per attribute. |
|
This is different from the behavior of ClassifierFreeGuidanceDropout as this allows for attributes |
|
to be dropped out separately. For example, "artist" can be dropped while "genre" remains. |
|
This is in contrast to ClassifierFreeGuidanceDropout where if "artist" is dropped "genre" |
|
must also be dropped. |
|
|
|
Args: |
|
        p (tp.Dict[str, tp.Dict[str, float]]): A dict mapping each condition type to a dict of

            attribute dropout probabilities. For example:

            {

                "text": {"genre": 0.1, "artist": 0.5},

                "wav": {...},

            }
|
active_on_eval (bool, optional): Whether the dropout is active at eval. Default to False. |
|
seed (int, optional): Random seed. |
|
""" |
|
def __init__(self, p: tp.Dict[str, tp.Dict[str, float]], active_on_eval: bool = False, seed: int = 1234): |
|
super().__init__(seed=seed) |
|
self.active_on_eval = active_on_eval |
|
|
|
self.p = {} |
|
for condition_type, probs in p.items(): |
|
self.p[condition_type] = defaultdict(lambda: 0, probs) |
|
|
|
def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]: |
|
""" |
|
Args: |
|
samples (list[ConditioningAttributes]): List of conditions. |
|
Returns: |
|
list[ConditioningAttributes]: List of conditions after certain attributes were set to None. |
|
""" |
|
if not self.training and not self.active_on_eval: |
|
return samples |
|
|
|
samples = deepcopy(samples) |
|
for condition_type, ps in self.p.items(): |
|
for condition, p in ps.items(): |
|
if torch.rand(1, generator=self.rng).item() < p: |
|
for sample in samples: |
|
dropout_condition(sample, condition_type, condition) |
|
return samples |
|
|
|
def __repr__(self): |
|
return f"AttributeDropout({dict(self.p)})" |
|
|
|
|
|
class ClassifierFreeGuidanceDropout(DropoutModule): |
|
"""Classifier Free Guidance dropout. |
|
All attributes are dropped with the same probability. |
|
|
|
Args: |
|
p (float): Probability to apply condition dropout during training. |
|
seed (int): Random seed. |
|
""" |
|
def __init__(self, p: float, seed: int = 1234): |
|
super().__init__(seed=seed) |
|
self.p = p |
|
|
|
def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]: |
|
""" |
|
Args: |
|
samples (list[ConditioningAttributes]): List of conditions. |
|
Returns: |
|
list[ConditioningAttributes]: List of conditions after all attributes were set to None. |
|
""" |
|
if not self.training: |
|
return samples |
|
|
|
|
|
drop = torch.rand(1, generator=self.rng).item() < self.p |
|
if not drop: |
|
return samples |
|
|
|
|
|
samples = deepcopy(samples) |
|
for condition_type in ["wav", "text"]: |
|
for sample in samples: |
|
for condition in sample.attributes[condition_type]: |
|
dropout_condition(sample, condition_type, condition) |
|
return samples |
|
|
|
def __repr__(self): |
|
return f"ClassifierFreeGuidanceDropout(p={self.p})" |
|
|
|
|
|
class ConditioningProvider(nn.Module): |
|
"""Prepare and provide conditions given all the supported conditioners. |
|
|
|
Args: |
|
conditioners (dict): Dictionary of conditioners. |
|
device (torch.device or str, optional): Device for conditioners and output condition types. |
|
""" |
|
def __init__(self, conditioners: tp.Dict[str, BaseConditioner], device: tp.Union[torch.device, str] = "cpu"): |
|
super().__init__() |
|
self.device = device |
|
self.conditioners = nn.ModuleDict(conditioners) |
|
|
|
@property |
|
def joint_embed_conditions(self): |
|
return [m.attribute for m in self.conditioners.values() if isinstance(m, JointEmbeddingConditioner)] |
|
|
|
@property |
|
def has_joint_embed_conditions(self): |
|
return len(self.joint_embed_conditions) > 0 |
|
|
|
@property |
|
def text_conditions(self): |
|
return [k for k, v in self.conditioners.items() if isinstance(v, TextConditioner)] |
|
|
|
@property |
|
def wav_conditions(self): |
|
return [k for k, v in self.conditioners.items() if isinstance(v, WaveformConditioner)] |
|
|
|
@property |
|
def has_wav_condition(self): |
|
return len(self.wav_conditions) > 0 |
|
|
|
def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]: |
|
"""Match attributes/wavs with existing conditioners in self, and compute tokenize them accordingly. |
|
This should be called before starting any real GPU work to avoid synchronization points. |
|
This will return a dict matching conditioner names to their arbitrary tokenized representations. |
|
|
|
Args: |
|
inputs (list[ConditioningAttributes]): List of ConditioningAttributes objects containing |
|
text and wav conditions. |
|
""" |
|
assert all([isinstance(x, ConditioningAttributes) for x in inputs]), ( |
|
"Got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]", |
|
f" but types were {set([type(x) for x in inputs])}" |
|
) |
|
|
|
output = {} |
|
text = self._collate_text(inputs) |
|
wavs = self._collate_wavs(inputs) |
|
joint_embeds = self._collate_joint_embeds(inputs) |
|
|
|
assert set(text.keys() | wavs.keys() | joint_embeds.keys()).issubset(set(self.conditioners.keys())), ( |
|
f"Got an unexpected attribute! Expected {self.conditioners.keys()}, ", |
|
f"got {text.keys(), wavs.keys(), joint_embeds.keys()}" |
|
) |
|
|
|
for attribute, batch in chain(text.items(), wavs.items(), joint_embeds.items()): |
|
output[attribute] = self.conditioners[attribute].tokenize(batch) |
|
return output |
|
|
|
def forward(self, tokenized: tp.Dict[str, tp.Any]) -> tp.Dict[str, ConditionType]: |
|
"""Compute pairs of `(embedding, mask)` using the configured conditioners and the tokenized representations. |
|
The output is for example: |
|
{ |
|
"genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])), |
|
"description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])), |
|
... |
|
} |
|
|
|
Args: |
|
tokenized (dict): Dict of tokenized representations as returned by `tokenize()`. |
|
""" |
|
output = {} |
|
for attribute, inputs in tokenized.items(): |
|
condition, mask = self.conditioners[attribute](inputs) |
|
output[attribute] = (condition, mask) |
|
return output |
|
|
|
def _collate_text(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.List[tp.Optional[str]]]: |
|
"""Given a list of ConditioningAttributes objects, compile a dictionary where the keys |
|
are the attributes and the values are the aggregated input per attribute. |
|
For example: |
|
Input: |
|
[ |
|
ConditioningAttributes(text={"genre": "Rock", "description": "A rock song with a guitar solo"}, wav=...), |
|
ConditioningAttributes(text={"genre": "Hip-hop", "description": "A hip-hop verse"}, wav=...), |
|
] |
|
Output: |
|
{ |
|
"genre": ["Rock", "Hip-hop"], |
|
"description": ["A rock song with a guitar solo", "A hip-hop verse"] |
|
} |
|
|
|
Args: |
|
samples (list of ConditioningAttributes): List of ConditioningAttributes samples. |
|
Returns: |
|
dict[str, list[str, optional]]: A dictionary mapping an attribute name to text batch. |
|
""" |
|
out: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list) |
|
texts = [x.text for x in samples] |
|
for text in texts: |
|
for condition in self.text_conditions: |
|
out[condition].append(text[condition]) |
|
return out |
|
|
|
def _collate_wavs(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, WavCondition]: |
|
"""Generate a dict where the keys are attributes by which we fetch similar wavs, |
|
and the values are Tensors of wavs according to said attributes. |
|
|
|
*Note*: by the time the samples reach this function, each sample should have some waveform |
|
inside the "wav" attribute. It should be either: |
|
1. A real waveform |
|
2. A null waveform due to the sample having no similar waveforms (nullified by the dataset) |
|
3. A null waveform due to it being dropped in a dropout module (nullified by dropout) |
|
|
|
Args: |
|
samples (list of ConditioningAttributes): List of ConditioningAttributes samples. |
|
Returns: |
|
dict[str, WavCondition]: A dictionary mapping an attribute name to wavs. |
|
""" |
|
wavs = defaultdict(list) |
|
lengths = defaultdict(list) |
|
sample_rates = defaultdict(list) |
|
paths = defaultdict(list) |
|
seek_times = defaultdict(list) |
|
out: tp.Dict[str, WavCondition] = {} |
|
|
|
for sample in samples: |
|
for attribute in self.wav_conditions: |
|
wav, length, sample_rate, path, seek_time = sample.wav[attribute] |
|
assert wav.dim() == 3, f"Got wav with dim={wav.dim()}, but expected 3 [1, C, T]" |
|
assert wav.size(0) == 1, f"Got wav [B, C, T] with shape={wav.shape}, but expected B == 1" |
|
|
|
wav = wav.mean(1, keepdim=True) |
|
wavs[attribute].append(wav.flatten()) |
|
lengths[attribute].append(length) |
|
sample_rates[attribute].extend(sample_rate) |
|
paths[attribute].extend(path) |
|
seek_times[attribute].extend(seek_time) |
|
|
|
|
|
for attribute in self.wav_conditions: |
|
stacked_wav, _ = collate(wavs[attribute], dim=0) |
|
out[attribute] = WavCondition( |
|
stacked_wav.unsqueeze(1), torch.cat(lengths[attribute]), sample_rates[attribute], |
|
paths[attribute], seek_times[attribute]) |
|
|
|
return out |
|
|
|
def _collate_joint_embeds(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, JointEmbedCondition]: |
|
"""Generate a dict where the keys are attributes by which we compute joint embeddings, |
|
and the values are Tensors of pre-computed embeddings and the corresponding text attributes. |
|
|
|
Args: |
|
samples (list[ConditioningAttributes]): List of ConditioningAttributes samples. |
|
Returns: |
|
A dictionary mapping an attribute name to joint embeddings. |
|
""" |
|
texts = defaultdict(list) |
|
wavs = defaultdict(list) |
|
lengths = defaultdict(list) |
|
sample_rates = defaultdict(list) |
|
paths = defaultdict(list) |
|
seek_times = defaultdict(list) |
|
channels: int = 0 |
|
|
|
out = {} |
|
for sample in samples: |
|
for attribute in self.joint_embed_conditions: |
|
wav, text, length, sample_rate, path, seek_time = sample.joint_embed[attribute] |
|
assert wav.dim() == 3 |
|
if channels == 0: |
|
channels = wav.size(1) |
|
else: |
|
assert channels == wav.size(1), "not all audio has same number of channels in batch" |
|
assert wav.size(0) == 1, "Expecting single-wav batch in the collate method" |
|
wav = einops.rearrange(wav, "b c t -> (b c t)") |
|
wavs[attribute].append(wav) |
|
texts[attribute].extend(text) |
|
lengths[attribute].append(length) |
|
sample_rates[attribute].extend(sample_rate) |
|
paths[attribute].extend(path) |
|
seek_times[attribute].extend(seek_time) |
|
|
|
for attribute in self.joint_embed_conditions: |
|
stacked_texts = texts[attribute] |
|
stacked_paths = paths[attribute] |
|
stacked_seek_times = seek_times[attribute] |
|
stacked_wavs = pad_sequence(wavs[attribute]).to(self.device) |
|
stacked_wavs = einops.rearrange(stacked_wavs, "(c t) b -> b c t", c=channels) |
|
stacked_sample_rates = sample_rates[attribute] |
|
stacked_lengths = torch.cat(lengths[attribute]).to(self.device) |
|
assert stacked_lengths.size(0) == stacked_wavs.size(0) |
|
assert len(stacked_sample_rates) == stacked_wavs.size(0) |
|
assert len(stacked_texts) == stacked_wavs.size(0) |
|
out[attribute] = JointEmbedCondition( |
|
text=stacked_texts, wav=stacked_wavs, |
|
length=stacked_lengths, sample_rate=stacked_sample_rates, |
|
path=stacked_paths, seek_time=stacked_seek_times) |
|
|
|
return out |
|
|
|
|
|
class ConditionFuser(StreamingModule): |
|
"""Condition fuser handles the logic to combine the different conditions |
|
to the actual model input. |
|
|
|
Args: |
|
fuse2cond (tp.Dict[str, str]): A dictionary that says how to fuse |
|
each condition. For example: |
|
{ |
|
"prepend": ["description"], |
|
"sum": ["genre", "bpm"], |
|
"cross": ["description"], |
|
} |
|
cross_attention_pos_emb (bool, optional): Use positional embeddings in cross attention. |
|
cross_attention_pos_emb_scale (int): Scale for positional embeddings in cross attention if used. |
|
""" |
|
FUSING_METHODS = ["sum", "prepend", "cross", "input_interpolate"] |
|
|
|
def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]], cross_attention_pos_emb: bool = False, |
|
cross_attention_pos_emb_scale: float = 1.0): |
|
super().__init__() |
|
assert all( |
|
[k in self.FUSING_METHODS for k in fuse2cond.keys()] |
|
), f"Got invalid fuse method, allowed methods: {self.FUSING_METHODS}" |
|
self.cross_attention_pos_emb = cross_attention_pos_emb |
|
self.cross_attention_pos_emb_scale = cross_attention_pos_emb_scale |
|
self.fuse2cond: tp.Dict[str, tp.List[str]] = fuse2cond |
|
self.cond2fuse: tp.Dict[str, str] = {} |
|
for fuse_method, conditions in fuse2cond.items(): |
|
for condition in conditions: |
|
self.cond2fuse[condition] = fuse_method |
|
|
|
def forward( |
|
self, |
|
input: torch.Tensor, |
|
conditions: tp.Dict[str, ConditionType] |
|
) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]: |
|
"""Fuse the conditions to the provided model input. |
|
|
|
Args: |
|
input (torch.Tensor): Transformer input. |
|
conditions (dict[str, ConditionType]): Dict of conditions. |
|
Returns: |
|
tuple[torch.Tensor, torch.Tensor]: The first tensor is the transformer input |
|
after the conditions have been fused. The second output tensor is the tensor |
|
used for cross-attention or None if no cross attention inputs exist. |
|
""" |
|
B, T, _ = input.shape |
|
|
|
if 'offsets' in self._streaming_state: |
|
first_step = False |
|
offsets = self._streaming_state['offsets'] |
|
else: |
|
first_step = True |
|
offsets = torch.zeros(input.shape[0], dtype=torch.long, device=input.device) |
|
|
|
assert set(conditions.keys()).issubset(set(self.cond2fuse.keys())), \ |
|
f"given conditions contain unknown attributes for fuser, " \ |
|
f"expected {self.cond2fuse.keys()}, got {conditions.keys()}" |
|
cross_attention_output = None |
|
for cond_type, (cond, cond_mask) in conditions.items(): |
|
op = self.cond2fuse[cond_type] |
|
if op == 'sum': |
|
input += cond |
|
elif op == 'input_interpolate': |
|
cond = einops.rearrange(cond, "b t d -> b d t") |
|
cond = F.interpolate(cond, size=input.shape[1]) |
|
input += einops.rearrange(cond, "b d t -> b t d") |
|
elif op == 'prepend': |
|
if first_step: |
|
input = torch.cat([cond, input], dim=1) |
|
elif op == 'cross': |
|
if cross_attention_output is not None: |
|
cross_attention_output = torch.cat([cross_attention_output, cond], dim=1) |
|
else: |
|
cross_attention_output = cond |
|
else: |
|
raise ValueError(f"unknown op ({op})") |
|
|
|
if self.cross_attention_pos_emb and cross_attention_output is not None: |
|
positions = torch.arange( |
|
cross_attention_output.shape[1], |
|
device=cross_attention_output.device |
|
).view(1, -1, 1) |
|
pos_emb = create_sin_embedding(positions, cross_attention_output.shape[-1]) |
|
cross_attention_output = cross_attention_output + self.cross_attention_pos_emb_scale * pos_emb |
|
|
|
if self._is_streaming: |
|
self._streaming_state['offsets'] = offsets + T |
|
|
|
return input, cross_attention_output |
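
# Example (illustrative sketch, not part of the original code): prepend one condition to the

# transformer input and feed another to cross-attention.

#   fuser = ConditionFuser({"prepend": ["genre"], "cross": ["description"]})

#   fused, cross = fuser(x, {"genre": (g_emb, g_mask), "description": (d_emb, d_mask)})

#   with x of shape [B, T, D]: fused is [B, T_genre + T, D] on the first (non-streaming) step,

#   and cross is [B, T_desc, D].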
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ConditionTensors = tp.Dict[str, ConditionType] |
|
CFGConditions = tp.Union[ConditionTensors, tp.Tuple[ConditionTensors, ConditionTensors]] |
|
|
|
|
|
def get_init_fn(method: str, input_dim: int, init_depth: tp.Optional[int] = None): |
|
"""LM layer initialization. |
|
Inspired from xlformers: https://github.com/fairinternal/xlformers |
|
|
|
Args: |
|
method (str): Method name for init function. Valid options are: |
|
'gaussian', 'uniform'. |
|
input_dim (int): Input dimension of the initialized module. |
|
init_depth (int, optional): Optional init depth value used to rescale |
|
the standard deviation if defined. |
|
""" |
|
|
|
std = 1 / math.sqrt(input_dim) |
|
|
|
if init_depth is not None: |
|
std = std / math.sqrt(2 * init_depth) |
|
|
|
if method == 'gaussian': |
|
return partial( |
|
torch.nn.init.trunc_normal_, mean=0.0, std=std, a=-3 * std, b=3 * std |
|
) |
|
elif method == 'uniform': |
|
bound = math.sqrt(3) * std |
|
return partial(torch.nn.init.uniform_, a=-bound, b=bound) |
|
else: |
|
raise ValueError("Unsupported layer initialization method") |
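
# Example (illustrative sketch, not part of the original code): the 'gaussian' method draws from

# a normal truncated at +/- 3 * std, with std = 1 / sqrt(input_dim), optionally rescaled by

# 1 / sqrt(2 * init_depth) when a depth is provided.

#   init_fn = get_init_fn('gaussian', input_dim=512, init_depth=4)

#   init_fn(torch.empty(512, 512))   # fills the tensor in place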
|
|
|
|
|
def init_layer(m: nn.Module, |
|
method: str, |
|
init_depth: tp.Optional[int] = None, |
|
zero_bias_init: bool = False): |
|
"""Wrapper around ``get_init_fn`` for proper initialization of LM modules. |
|
|
|
Args: |
|
m (nn.Module): Module to initialize. |
|
method (str): Method name for the init function. |
|
init_depth (int, optional): Optional init depth value used to rescale |
|
the standard deviation if defined. |
|
zero_bias_init (bool): Whether to initialize the bias to 0 or not. |
|
""" |
|
if isinstance(m, nn.Linear): |
|
init_fn = get_init_fn(method, m.in_features, init_depth=init_depth) |
|
if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16: |
|
weight = m.weight.float() |
|
init_fn(weight) |
|
m.weight.data[:] = weight.half() |
|
else: |
|
init_fn(m.weight) |
|
if zero_bias_init and m.bias is not None: |
|
nn.init.constant_(m.bias, 0) |
|
elif isinstance(m, nn.Embedding): |
|
init_fn = get_init_fn(method, m.embedding_dim, init_depth=None) |
|
if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16: |
|
weight = m.weight.float() |
|
init_fn(weight) |
|
m.weight.data[:] = weight.half() |
|
else: |
|
init_fn(m.weight) |
|
|
|
|
|
class ScaledEmbedding(nn.Embedding): |
|
"""Boost learning rate for embeddings (with `scale`). |
|
""" |
|
def __init__(self, *args, lr=None, **kwargs): |
|
super().__init__(*args, **kwargs) |
|
self.lr = lr |
|
|
|
def make_optim_group(self): |
|
group = {"params": list(self.parameters())} |
|
if self.lr is not None: |
|
group["lr"] = self.lr |
|
return group |
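
# Example (illustrative sketch, not part of the original code): the dedicated param group lets an

# optimizer apply a specific learning rate to the embedding table only.

#   emb = ScaledEmbedding(1024, 128, lr=1e-3)

#   opt = torch.optim.AdamW([emb.make_optim_group()], lr=1e-4)   # the embedding trains at lr=1e-3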
|
|
|
|
|
@dataclass |
|
class LMOutput:

    # The logits are already re-aligned with the input codes, hence no extra shift is

    # required when computing a loss such as cross-entropy.

    logits: torch.Tensor  # [B, K, T, card]

    # Mask over valid and invalid positions in the codes and logits.

    mask: torch.Tensor  # [B, K, T]
|
|
|
|
|
class LMModel(StreamingModule): |
|
"""Transformer-based language model on multiple streams of codes. |
|
|
|
Args: |
|
pattern_provider (CodebooksPatternProvider): Pattern provider for codebook interleaving. |
|
        condition_provider (ConditioningProvider): Conditioning provider from metadata.
|
fuser (ConditionFuser): Fuser handling the fusing of conditions with language model input. |
|
n_q (int): Number of parallel streams to model. |
|
card (int): Cardinality, vocabulary size. |
|
dim (int): Dimension of the transformer encoder. |
|
num_heads (int): Number of heads for the transformer encoder. |
|
hidden_scale (int): Scale for hidden feed forward dimension of the transformer encoder. |
|
norm (str): Normalization method. |
|
norm_first (bool): Use pre-norm instead of post-norm. |
|
emb_lr (float, optional): Embedding-specific learning rate. |
|
bias_proj (bool): Use bias for output projections. |
|
weight_init (str, optional): Method for weight initialization. |
|
depthwise_init (str, optional): Method for depthwise weight initialization. |
|
zero_bias_init (bool): If true and bias in Linears, initialize bias to zeros. |
|
cfg_dropout (float): Classifier-free guidance dropout. |
|
cfg_coef (float): Classifier-free guidance coefficient. |
|
attribute_dropout (dict): Attribute dropout probabilities. |
|
two_step_cfg (bool): Whether to run classifier free-guidance with 2 distinct steps. |
|
**kwargs: Additional parameters for the transformer encoder. |
|
""" |
|
def __init__(self, pattern_provider: CodebooksPatternProvider, condition_provider: ConditioningProvider, |
|
fuser: ConditionFuser, n_q: int = 8, card: int = 1024, dim: int = 128, num_heads: int = 8, |
|
hidden_scale: int = 4, norm: str = 'layer_norm', norm_first: bool = False, |
|
emb_lr: tp.Optional[float] = None, bias_proj: bool = True, |
|
weight_init: tp.Optional[str] = None, depthwise_init: tp.Optional[str] = None, |
|
zero_bias_init: bool = False, cfg_dropout: float = 0, cfg_coef: float = 1.0, |
|
attribute_dropout: tp.Dict[str, tp.Dict[str, float]] = {}, two_step_cfg: bool = False, |
|
**kwargs): |
|
super().__init__() |
|
self.cfg_coef = cfg_coef |
|
self.cfg_dropout = ClassifierFreeGuidanceDropout(p=cfg_dropout) |
|
self.att_dropout = AttributeDropout(p=attribute_dropout) |
|
self.condition_provider = condition_provider |
|
self.fuser = fuser |
|
self.card = card |
|
embed_dim = self.card + 1 |
|
self.n_q = n_q |
|
self.dim = dim |
|
self.pattern_provider = pattern_provider |
|
self.two_step_cfg = two_step_cfg |
|
self.emb = nn.ModuleList([ScaledEmbedding(embed_dim, dim, lr=emb_lr) for _ in range(n_q)]) |
|
if 'activation' in kwargs: |
|
kwargs['activation'] = get_activation_fn(kwargs['activation']) |
|
self.transformer = StreamingTransformer( |
|
d_model=dim, num_heads=num_heads, dim_feedforward=int(hidden_scale * dim), |
|
norm=norm, norm_first=norm_first, **kwargs) |
|
self.out_norm: tp.Optional[nn.Module] = None |
|
if norm_first: |
|
self.out_norm = create_norm_fn(norm, dim) |
|
self.linears = nn.ModuleList([nn.Linear(dim, self.card, bias=bias_proj) for _ in range(n_q)]) |
|
self._init_weights(weight_init, depthwise_init, zero_bias_init) |
|
self._fsdp: tp.Optional[nn.Module] |
|
self.__dict__['_fsdp'] = None |
|
|
|
def _init_weights(self, weight_init: tp.Optional[str], depthwise_init: tp.Optional[str], zero_bias_init: bool): |
|
"""Initialization of the transformer module weights. |
|
|
|
Args: |
|
weight_init (str, optional): Weight initialization strategy. See ``get_init_fn`` for valid options. |
|
depthwise_init (str, optional): Depthwise initialization strategy. The following options are valid: |
|
'current' where the depth corresponds to the current layer index or 'global' where the total number |
|
                of layers is used as depth. If not set, no depthwise initialization strategy is used.
|
zero_bias_init (bool): Whether to initialize bias to zero or not. |
|
""" |
|
assert depthwise_init is None or depthwise_init in ['current', 'global'] |
|
assert depthwise_init is None or weight_init is not None, \ |
|
"If 'depthwise_init' is defined, a 'weight_init' method should be provided." |
|
assert not zero_bias_init or weight_init is not None, \ |
|
"If 'zero_bias_init', a 'weight_init' method should be provided" |
|
|
|
if weight_init is None: |
|
return |
|
|
|
for emb_layer in self.emb: |
|
init_layer(emb_layer, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init) |
|
|
|
for layer_idx, tr_layer in enumerate(self.transformer.layers): |
|
depth = None |
|
if depthwise_init == 'current': |
|
depth = layer_idx + 1 |
|
elif depthwise_init == 'global': |
|
depth = len(self.transformer.layers) |
|
init_fn = partial(init_layer, method=weight_init, init_depth=depth, zero_bias_init=zero_bias_init) |
|
tr_layer.apply(init_fn) |
|
|
|
for linear in self.linears: |
|
init_layer(linear, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init) |
|
|
|
@property |
|
def special_token_id(self) -> int: |
|
return self.card |
|
|
|
@property |
|
def num_codebooks(self) -> int: |
|
return self.n_q |
|
|
|
def forward(self, sequence: torch.Tensor, |
|
conditions: tp.List[ConditioningAttributes], |
|
condition_tensors: tp.Optional[ConditionTensors] = None, |
|
stage: int = -1) -> torch.Tensor: |
|
"""Apply language model on sequence and conditions. |
|
        Given a sequence tensor of shape [B, K, S], with K the number of codebooks and

        S the sequence steps, return the logits of shape [B, K, S, card].
|
|
|
Args: |
|
            sequence (torch.Tensor): Indices of the codes to model, of shape [B, K, S].
|
conditions (list of ConditioningAttributes): Conditions to use when modeling |
|
                the given codes. Note that when evaluating multiple times with the same conditioning
|
you should pre-compute those and pass them as `condition_tensors`. |
|
condition_tensors (dict[str, ConditionType], optional): Pre-computed conditioning |
|
tensors, see `conditions`. |
|
stage (int): The codebook level that is being predicted. Relevant for MAGNeT |
|
in which prediction is done in a codebook-by-codebook manner. |
|
Takes values in range(n_q), and ignored by default. |
|
Returns: |
|
torch.Tensor: Logits of shape [B, K, S, card].
|
""" |
|
B, K, S = sequence.shape |
|
assert K == self.num_codebooks, "Sequence shape must match the specified number of codebooks" |
|
input_ = sum([self.emb[k](sequence[:, k]) for k in range(K)]) |
|
if condition_tensors is None: |
|
assert not self._is_streaming, "Condition tensors should be precomputed when streaming."
|
|
|
conditions = self.cfg_dropout(conditions) |
|
conditions = self.att_dropout(conditions) |
|
tokenized = self.condition_provider.tokenize(conditions) |
|
|
|
condition_tensors = self.condition_provider(tokenized) |
|
else: |
|
assert not conditions, "Shouldn't pass both conditions and condition_tensors." |
|
|
|
input_, cross_attention_input = self.fuser(input_, condition_tensors) |
|
|
|
out = self.transformer(input_, cross_attention_src=cross_attention_input, |
|
src_mask=(self.attn_mask_per_stage[stage] if stage >= 0 else None)) |
|
if self.out_norm: |
|
out = self.out_norm(out) |
|
logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1) |
|
|
|
|
|
# If some conditions were fused by prepending them to the input sequence, the transformer
# output is longer than S; keep only the last S steps corresponding to the actual codes.
if len(self.fuser.fuse2cond['prepend']) > 0:
    logits = logits[:, :, -S:]
|
|
|
return logits |
|
|
|
def compute_predictions( |
|
self, codes: torch.Tensor, |
|
conditions: tp.List[ConditioningAttributes], |
|
condition_tensors: tp.Optional[ConditionTensors] = None, |
|
stage: int = -1, |
|
keep_only_valid_steps: bool = True) -> LMOutput: |
|
"""Given an input tensor of codes [B, K, T] and list of conditions, runs the model |
|
forward using the specified codes interleaving pattern. |
|
|
|
Args: |
|
codes (torch.Tensor): Input codes of shape [B, K, T] with B the batch size, |
|
K the number of codebooks and T the number of timesteps. |
|
conditions (list of ConditioningAttributes): conditionings to use when modeling |
|
the given codes. Note that when evaluating multiple times with the same conditioning
you should pre-compute the conditioning tensors and pass them as `condition_tensors`.
|
condition_tensors (dict[str, ConditionType], optional): pre-computed conditioning |
|
tensors, see `conditions`. |
|
stage (int): The codebook level that is being predicted. Relevant for MAGNeT |
|
in which prediction is done in a codebook-by-codebook manner. |
|
Takes values in range(n_q), and ignored by default. |
|
keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps. |
|
Steps that are beyond valid steps will be replaced by the special_token in that case. |
|
Returns: |
|
LMOutput: Language model outputs |
|
logits (torch.Tensor) of shape [B, K, T, card] corresponding to the provided codes, |
|
i.e. the first item corresponds to logits to predict the first code, meaning that |
|
no additional shifting of codes and logits is required. |
|
mask (torch.Tensor) of shape [B, K, T], mask over valid and invalid positions. |
|
Given the specified interleaving strategies, parts of the logits and codes should |
|
not be considered as valid predictions because of invalid context. |
|
""" |
|
B, K, T = codes.shape |
|
codes = codes.contiguous() |
|
|
|
pattern = self.pattern_provider.get_pattern(T) |
|
sequence_codes, sequence_indexes, sequence_mask = pattern.build_pattern_sequence( |
|
codes, self.special_token_id, keep_only_valid_steps=keep_only_valid_steps, |
|
) |
|
|
|
|
|
model = self if self._fsdp is None else self._fsdp |
|
logits = model(sequence_codes, conditions, condition_tensors, stage=stage) |
|
|
|
|
|
logits = logits.permute(0, 3, 1, 2) |
|
|
|
logits, logits_indexes, logits_mask = pattern.revert_pattern_logits( |
|
logits, float('nan'), keep_only_valid_steps=keep_only_valid_steps |
|
) |
|
logits = logits.permute(0, 2, 3, 1) |
|
logits_mask = logits_mask[None, :, :].expand(B, -1, -1) |
|
return LMOutput(logits, logits_mask) |
|
|
|
def _sample_next_token(self, |
|
sequence, |
|
cfg_conditions, |
|
unconditional_state, |
|
use_sampling=False, |
|
temp: float = 1.0, |
|
top_k: int = 0, |
|
top_p: float = 0.0, |
|
cfg_coef: tp.Optional[float] = None, |
|
two_step_cfg: tp.Optional[bool] = None) -> torch.Tensor: |
|
"""Sample next token from the model given a sequence and a set of conditions. The model supports |
|
multiple sampling strategies (greedy sampling, softmax, top-k, top-p...). |
|
|
|
Args: |
|
sequence (torch.Tensor): Current sequence of shape [B, K, S] |
|
with K corresponding to the number of codebooks and S the number of sequence steps. |
|
S = 1 in streaming mode, except for the first step that contains a bigger prompt. |
|
cfg_conditions (CFGConditions): Set of pre-computed conditions. If CFG is used, the batch
dimension should be twice the batch size, being the concatenation of the conditions and the
null conditions (or a tuple of conditional / null-conditional tensors when `two_step_cfg` is used).
unconditional_state (State): Streaming state of the unconditional forward pass (only used
with `two_step_cfg`).
|
use_sampling (bool): Whether to use a sampling strategy or not. |
|
temp (float): Sampling temperature. |
|
top_k (int): K for "top-k" sampling. |
|
top_p (float): P for "top-p" sampling. |
|
cfg_coef (float, optional): Classifier-free guidance coefficient.
two_step_cfg (bool, optional): Whether to run separate conditional and unconditional forward passes.
|
Returns: |
|
next_token (torch.Tensor): Next token tensor of shape [B, K, 1]. |
|
""" |
|
B = sequence.shape[0] |
|
cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef |
|
model = self if self._fsdp is None else self._fsdp |
|
two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg |
|
if two_step_cfg and cfg_conditions != {}: |
|
assert isinstance(cfg_conditions, tuple), type(cfg_conditions) |
|
condition_tensors, null_condition_tensors = cfg_conditions |
|
cond_logits = model(sequence, conditions=[], condition_tensors=condition_tensors) |
|
state = self.get_streaming_state() |
|
self.set_streaming_state(unconditional_state) |
|
uncond_logits = model(sequence, conditions=[], condition_tensors=null_condition_tensors) |
|
unconditional_state.update(self.get_streaming_state()) |
|
self.set_streaming_state(state) |
|
logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef
|
else: |
|
assert isinstance(cfg_conditions, dict) |
|
condition_tensors = cfg_conditions |
|
if condition_tensors: |
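# The conditional and null-conditional tensors were fused along the batch dimension
# (see `generate`), so duplicating the token sequence lets a single forward pass
# produce both sets of logits, which are split back apart below.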
|
|
|
sequence = torch.cat([sequence, sequence], dim=0) |
|
all_logits = model( |
|
sequence, |
|
conditions=[], condition_tensors=condition_tensors) |
|
if condition_tensors: |
|
cond_logits, uncond_logits = all_logits.split(B, dim=0) |
|
logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef |
|
else: |
|
logits = all_logits |
|
|
|
logits = logits.permute(0, 1, 3, 2)  # [B, K, card, S]
logits = logits[..., -1]  # keep only the logits of the last step: [B, K, card]
|
|
|
|
|
if use_sampling and temp > 0.0: |
|
probs = torch.softmax(logits / temp, dim=-1) |
|
if top_p > 0.0: |
|
next_token = utils.sample_top_p(probs, p=top_p) |
|
elif top_k > 0: |
|
next_token = utils.sample_top_k(probs, k=top_k) |
|
else: |
|
next_token = utils.multinomial(probs, num_samples=1) |
|
else: |
|
next_token = torch.argmax(logits, dim=-1, keepdim=True) |
|
|
|
return next_token |
|
|
|
@torch.no_grad() |
|
def generate(self, |
|
prompt: tp.Optional[torch.Tensor] = None, |
|
conditions: tp.List[ConditioningAttributes] = [], |
|
num_samples: tp.Optional[int] = None, |
|
max_gen_len: int = 256, |
|
use_sampling: bool = True, |
|
temp: float = 1.0, |
|
top_k: int = 250, |
|
top_p: float = 0.0, |
|
cfg_coef: tp.Optional[float] = None, |
|
two_step_cfg: tp.Optional[bool] = None, |
|
remove_prompts: bool = False, |
|
check: bool = False, |
|
callback: tp.Optional[tp.Callable[[int, int], None]] = None, |
|
**kwargs) -> torch.Tensor: |
|
"""Generate tokens sampling from the model given a prompt or unconditionally. Generation can |
|
be performed in a greedy fashion or using sampling with top K and top P strategies. |
|
|
|
Args: |
|
prompt (torch.Tensor, optional): Prompt tokens of shape [B, K, T]. |
|
conditions (list of ConditioningAttributes): List of conditions.
|
num_samples (int, optional): Number of samples to generate when no prompt and no conditions are given. |
|
max_gen_len (int): Maximum generation length. |
|
use_sampling (bool): Whether to use a sampling strategy or not. |
|
temp (float): Sampling temperature. |
|
top_k (int): K for "top-k" sampling. |
|
top_p (float): P for "top-p" sampling. |
|
cfg_coef (float, optional): Classifier-free guidance coefficient.
|
two_step_cfg (bool, optional): Whether to perform classifier-free guidance with two steps generation. |
|
remove_prompts (bool): Whether to remove prompts from generation or not. |
|
check (bool): Whether to apply further checks on generated sequence. |
|
callback (Callback, optional): Callback function to report generation progress. |
|
Returns: |
|
torch.Tensor: Generated tokens. |
|
""" |
|
assert not self.training, "generation shouldn't be used in training mode." |
|
first_param = next(iter(self.parameters())) |
|
device = first_param.device |
|
|
|
|
|
possible_num_samples = [] |
|
if num_samples is not None: |
|
possible_num_samples.append(num_samples) |
|
elif prompt is not None: |
|
possible_num_samples.append(prompt.shape[0]) |
|
elif conditions: |
|
possible_num_samples.append(len(conditions)) |
|
else: |
|
possible_num_samples.append(1) |
|
assert all(x == possible_num_samples[0] for x in possible_num_samples), "Inconsistent input shapes"
|
num_samples = possible_num_samples[0]

# Prepare the conditions for classifier-free guidance (CFG): one conditional set and one
# unconditional (null) set. By default both are fused along the batch dimension so that a
# single forward pass per step is enough; with `two_step_cfg`, two separate passes are run.
cfg_conditions: CFGConditions |
|
two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg |
|
if conditions: |
|
null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions) |
|
if two_step_cfg: |
|
cfg_conditions = ( |
|
self.condition_provider(self.condition_provider.tokenize(conditions)), |
|
self.condition_provider(self.condition_provider.tokenize(null_conditions)), |
|
) |
|
else: |
|
conditions = conditions + null_conditions |
|
tokenized = self.condition_provider.tokenize(conditions) |
|
cfg_conditions = self.condition_provider(tokenized) |
|
else: |
|
cfg_conditions = {} |
|
|
|
if prompt is None: |
|
assert num_samples > 0 |
|
prompt = torch.zeros((num_samples, self.num_codebooks, 0), dtype=torch.long, device=device) |
|
|
|
B, K, T = prompt.shape |
|
start_offset = T |
|
assert start_offset < max_gen_len |
|
|
|
pattern = self.pattern_provider.get_pattern(max_gen_len) |
|
|
|
unknown_token = -1  # placeholder for positions that have not been generated yet
|
|
|
|
|
gen_codes = torch.full((B, K, max_gen_len), unknown_token, dtype=torch.long, device=device)
# fill gen_codes with the prompt tokens, if any
gen_codes[..., :start_offset] = prompt
|
|
|
gen_sequence, indexes, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id) |
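# gen_sequence has shape [B, K, S], with S the pattern sequence length; positions the pattern
# cannot fill are set to the special token, while the remaining positions start as
# unknown_token (or prompt codes) and are filled step by step below. mask has shape [K, S].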
|
|
|
|
|
start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset) |
|
assert start_offset_sequence is not None |
|
|
|
with self.streaming(): |
|
unconditional_state = self.get_streaming_state() |
|
prev_offset = 0 |
|
gen_sequence_len = gen_sequence.shape[-1] |
|
for offset in range(start_offset_sequence, gen_sequence_len): |
|
|
|
curr_sequence = gen_sequence[..., prev_offset:offset] |
|
curr_mask = mask[None, ..., prev_offset:offset].expand(B, -1, -1) |
|
if check: |
|
|
|
assert (curr_sequence == torch.where(curr_mask, curr_sequence, self.special_token_id)).all() |
|
|
|
assert not (curr_sequence == unknown_token).any() |
|
|
|
next_token = self._sample_next_token( |
|
curr_sequence, cfg_conditions, unconditional_state, use_sampling, temp, top_k, top_p, |
|
cfg_coef=cfg_coef, two_step_cfg=two_step_cfg) |
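# The pattern marks some (codebook, step) positions as invalid at this offset; those are
# forced to the special token (the model never predicts it), and positions already known
# (e.g. from the prompt) are kept as-is rather than overwritten.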
|
|
|
|
|
valid_mask = mask[..., offset:offset+1].expand(B, -1, -1) |
|
next_token[~valid_mask] = self.special_token_id |
|
|
|
|
|
gen_sequence[..., offset:offset+1] = torch.where( |
|
gen_sequence[..., offset:offset+1] == unknown_token, |
|
next_token, gen_sequence[..., offset:offset+1] |
|
) |
|
prev_offset = offset |
|
if callback is not None: |
|
callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence) |
|
unconditional_state.clear() |
|
|
|
|
|
assert not (gen_sequence == unknown_token).any()  # every position must have been generated by now
|
|
|
|
|
assert ( |
|
gen_sequence == torch.where(mask[None, ...].expand(B, -1, -1), gen_sequence, self.special_token_id) |
|
).all() |
|
|
|
out_codes, out_indexes, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token) |
|
|
|
|
|
assert (out_codes[..., :max_gen_len] != unknown_token).all() |
|
assert (out_mask[..., :max_gen_len] == 1).all() |
|
|
|
out_start_offset = start_offset if remove_prompts else 0 |
|
out_codes = out_codes[..., out_start_offset:max_gen_len] |
|
|
|
|
|
assert (out_codes >= 0).all() and (out_codes <= self.card).all() |
|
return out_codes |
|
|