| import random |
| from dataclasses import dataclass |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from einops import rearrange |
| from module import SinePositionalEmbedding, TokenEmbedding, top_k_sampling |
| from simple_parsing import Serializable |
| from torch.cuda.amp import autocast |
| from torchmetrics.classification import MulticlassAccuracy |
| from transformer import AdaptiveLayerNorm, LayerNorm, TransformerEncoder, TransformerEncoderLayer |
| from transformers import EncodecModel |
| from vocos import Vocos |
|
|
| from audiozen.acoustics.audio_feature import stft |
|
|
|
|
@dataclass
class ModelArgs(Serializable):
    """Hyper-parameters consumed by ``Model``.

    Fields are declared in a fixed order (dataclass positional construction
    depends on it); every field has a default, so ``ModelArgs()`` is valid.
    """

    # -- codec / vocabulary -------------------------------------------------
    # Number of codec codebooks per frame; the model indexes codes[..., j]
    # for j in range(num_cb).
    num_cb: int = 8
    # Presumably entries per codebook (TODO confirm; not used by the main
    # forward/generate paths).
    cb_size: int = 1024
    # Tokens per codebook; the id equal to num_tokens is used as EOS.
    num_tokens: int = 1024

    # -- transformer --------------------------------------------------------
    d_model: int = 512
    num_layers: int = 12
    num_heads: int = 8
    # Pre-norm layers; also enables a final norm on the decoders.
    norm_first: bool = True

    # -- mixture spectrogram branch ----------------------------------------
    n_fft: int = 768
    hop_length: int = 384

    # -- misc switches ------------------------------------------------------
    # Tie NAR prediction weights to NAR token embeddings.
    share_embedding: bool = True
    # Adds one extra row to the AR embedding table (presumably for BOS).
    prepend_bos: bool = False
    # NOTE(review): not read anywhere in the model; prenets are nn.Identity.
    add_prenet: bool = False
    # 1 -> train the AR decoder only; any other value trains AR + NAR.
    stage: int = 2
|
|
|
|
class Model(nn.Module):
    """Two-speaker separation model that predicts EnCodec codes of each speaker.

    The two speakers' code sequences are concatenated along time (speaker 1
    first, then speaker 2). An autoregressive (AR) transformer predicts
    codebook 0 of that sequence token by token; a non-autoregressive (NAR)
    transformer then predicts codebooks 1..num_cb-1, one stage at a time.
    Both decoders are conditioned on a prefix made of
    ``[spectrogram features of the mixture, embeddings of the mixture codes]``.
    """

    def __init__(self, args: "ModelArgs | None" = None):
        """Build all sub-modules.

        Args:
            args: hyper-parameters. When omitted, a fresh ``ModelArgs()`` is
                created so separate models never share one default instance.
        """
        super().__init__()
        if args is None:
            args = ModelArgs()

        # Frozen pretrained codec, only used to tokenize / detokenize audio.
        self.encodec = EncodecModel.from_pretrained("facebook/encodec_24khz")
        for param in self.encodec.parameters():
            param.requires_grad = False

        # ---------------- AR branch ----------------
        self.ar_audio_prepend_bos = args.prepend_bos
        # Vocabulary: num_tokens codes + EOS (+ one extra row when prepend_bos).
        self.ar_embedding_layer = TokenEmbedding(args.d_model, args.num_tokens + 1 + int(args.prepend_bos))
        # Per-frame MLP over the log-magnitude spectrum (n_fft // 2 + 1 bins).
        self.ar_spec_encoder = nn.Sequential(
            nn.Linear(args.n_fft // 2 + 1, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, args.d_model),
        )
        self.ar_prenet = nn.Identity()
        self.ar_position = SinePositionalEmbedding(args.d_model, dropout=0.1, scale=False, alpha=True)

        self.ar_decoder = TransformerEncoder(
            TransformerEncoderLayer(
                args.d_model,
                args.num_heads,
                dim_feedforward=args.d_model * 4,
                dropout=0.1,
                batch_first=True,
                norm_first=args.norm_first,
            ),
            num_layers=args.num_layers,
            norm=LayerNorm(args.d_model) if args.norm_first else None,
        )

        # Predicts num_tokens codes + EOS.
        self.ar_pred_layer = nn.Linear(args.d_model, args.num_tokens + 1, bias=False)

        self.ar_acc_metric = MulticlassAccuracy(
            args.num_tokens + 1,
            top_k=10,
            average="micro",
            multidim_average="global",
            ignore_index=args.num_tokens,
        )

        # ---------------- NAR branch ----------------
        self.nar_spec_encoder = nn.Sequential(
            nn.Linear(args.n_fft // 2 + 1, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, args.d_model),
        )
        # Layer j embeds codebook j; only codebook 0 may contain EOS.
        self.nar_embedding_layers = nn.ModuleList(
            [TokenEmbedding(args.d_model, args.num_tokens + 1)]
            + [TokenEmbedding(args.d_model, args.num_tokens) for _ in range(args.num_cb - 1)]
        )
        self.nar_prenet = nn.Identity()
        self.nar_position = SinePositionalEmbedding(args.d_model, dropout=0.1, scale=False, alpha=False)

        self.nar_decoder = TransformerEncoder(
            TransformerEncoderLayer(
                args.d_model,
                args.num_heads,
                dim_feedforward=args.d_model * 4,
                dropout=0.1,
                batch_first=True,
                norm_first=args.norm_first,
                adaptive_layer_norm=True,
            ),
            num_layers=args.num_layers,
            norm=AdaptiveLayerNorm(args.d_model, norm=nn.LayerNorm(args.d_model)) if args.norm_first else None,
        )

        # Prediction layer j predicts codebook j + 1.
        self.nar_pred_layers = nn.ModuleList(
            [nn.Linear(args.d_model, args.num_tokens, bias=False) for _ in range(args.num_cb - 1)]
        )

        # One learned vector per NAR stage, fed to the adaptive layer norms.
        self.nar_stage_embeddings = nn.ModuleList([TokenEmbedding(args.d_model, 1) for i in range(args.num_cb - 1)])

        if args.share_embedding:
            # NOTE(review): prediction layer j (codebook j + 1) is tied to
            # embedding j + 2, and the last prediction layer stays untied.
            # Kept as-is so existing checkpoints remain compatible; confirm
            # the offset is intended.
            for j in range(0, args.num_cb - 2):
                self.nar_pred_layers[j].weight = self.nar_embedding_layers[j + 2].weight

        self.nar_acc_metric = MulticlassAccuracy(
            args.num_tokens + 1,
            top_k=10,
            average="micro",
            multidim_average="global",
            ignore_index=args.num_tokens,
        )

        self.args = args
        self.rng = random.Random(0)
        # Lazily loaded vocoders keyed by device string. A plain dict keeps
        # them out of the module tree (no state_dict / parameters() changes).
        self._vocos_cache = {}

    # ------------------------------------------------------------------ codec

    def _encodec_encode(self, waveform):
        """Encode waveform to codes.

        Args:
            waveform: shape of [B, T]

        Returns:
            codes: shape of [B, T, N_q], dtype long. Assumes the codec returns
            a single chunk; chunks are otherwise folded into the batch axis.
        """
        # Run the frozen codec in full precision even inside an AMP region.
        with torch.no_grad(), autocast(enabled=False):
            waveform = rearrange(waveform, "b t -> b () t")
            encoded = self.encodec.encode(input_values=waveform, return_dict=True, bandwidth=6)
            codes = rearrange(encoded.audio_codes, "c b nq t -> (c b) t nq")
        return codes.to(dtype=torch.long)

    def _encodec_decode(self, codes):
        """Decode codes with shape [B, T, N_q] to waveforms with shape [B, T]."""
        with torch.no_grad():
            codes = rearrange(codes, "b t nq -> 1 b nq t")
            audio_values = self.encodec.decode(audio_codes=codes, audio_scales=[None], return_dict=True).audio_values
            audio_values = rearrange(audio_values, "b 1 t -> b t")
        return audio_values

    def _vocos_decode(self, codes):
        """Decode codes with shape [B, T, N_q] to audio with the Vocos vocoder.

        The pretrained vocoder is loaded once per device and reused.
        """
        key = str(codes.device)
        vocos = self._vocos_cache.get(key)
        if vocos is None:
            vocos = Vocos.from_pretrained("charactr/vocos-encodec-24khz").to(codes.device)
            self._vocos_cache[key] = vocos
        with torch.no_grad():
            codes = codes.permute(2, 0, 1)  # [N_q, B, T]
            features = vocos.codes_to_features(codes)
            # bandwidth_id=2 presumably selects 6 kbps, matching bandwidth=6 in
            # _encodec_encode (confirm against the Vocos docs).
            audio_values = vocos.decode(features, bandwidth_id=torch.tensor([2], device=codes.device))
        return audio_values

    # ---------------------------------------------------------------- helpers

    def _stack_embeddings(self, codes, cb_idx):
        """Sum the NAR token embeddings of codebooks 0..``cb_idx`` (inclusive).

        Args:
            codes: shape of [B, T, N_q]
            cb_idx: index of the last codebook to include (>= 0)

        Returns:
            Summed embeddings, presumably [B, T, d_model] (TokenEmbedding width).
        """
        embed = self.nar_embedding_layers[0](codes[..., 0])
        for i in range(1, cb_idx + 1):
            embed = embed + self.nar_embedding_layers[i](codes[..., i])
        return embed

    def _ar_mix_embed(self, mix_codes):
        """Sum the AR embedding of every codebook of the mixture codes [B, T, N_q]."""
        embed = self.ar_embedding_layer(mix_codes[..., 0])
        for j in range(1, self.args.num_cb):
            embed = embed + self.ar_embedding_layer(mix_codes[..., j])
        return embed

    def _log_magnitude(self, waveform):
        """Return the log-magnitude STFT of ``waveform`` [B, T] as [B, frames, freq]."""
        mag, *_ = stft(waveform, n_fft=self.args.n_fft, hop_length=self.args.hop_length, win_length=self.args.n_fft)
        mag = rearrange(mag, "b f t -> b t f")
        return torch.log(mag + 1e-8)

    @staticmethod
    def _ar_attn_mask(prefix_len, sep_len, device):
        """Build the [L, L] boolean AR mask (True = blocked), L = prefix_len + sep_len.

        Prefix positions attend only to the prefix; target positions attend to
        the whole prefix plus earlier/current target positions (causal).
        """
        prefix_mask = F.pad(
            torch.zeros((prefix_len, prefix_len), dtype=torch.bool, device=device), (0, sep_len), value=True
        )
        sep_mask = F.pad(
            torch.triu(torch.ones(sep_len, sep_len, dtype=torch.bool, device=device), diagonal=1),
            (prefix_len, 0),
            value=False,
        )
        return torch.cat([prefix_mask, sep_mask], dim=0)

    def _pad_eos(self, codes, eos_id):
        """Split codes [B, T] into teacher-forcing inputs and next-token targets.

        Returns:
            ``(inputs, targets)``, both [B, T]: ``inputs`` equals ``codes`` and
            ``targets`` is ``codes`` shifted left by one with ``eos_id`` appended.
        """
        codes = F.pad(codes, (0, 1), value=eos_id)
        return codes[..., :-1], codes[..., 1:]

    # --------------------------------------------------------------- training

    def forward(self, mix_wave, sep_wav):
        """Compute the AR (and, unless ``args.stage == 1``, NAR) training losses.

        Args:
            mix_wave: mixture waveform, shape of [B, T]
            sep_wav: reference waveforms of the two speakers, shape of [B, 2, T]

        Returns:
            ``(total_loss, ar_loss, nar_loss, ar_accuracy, nar_accuracy)``. With
            ``args.stage == 1`` the NAR branch is skipped: ``total_loss`` is the
            AR loss and both NAR entries are ``0.0``.
        """
        batch_size, num_spks, _ = sep_wav.shape
        device = sep_wav.device
        assert num_spks == 2, f"Only support to process two speakers, but got {num_spks}"

        # Tokenize. The two speakers are concatenated along time.
        mix_codes = self._encodec_encode(mix_wave)
        spk_1_codes, spk_2_codes = self._encodec_encode(sep_wav[:, 0]), self._encodec_encode(sep_wav[:, 1])
        sep_codes = torch.cat([spk_1_codes, spk_2_codes], dim=1)

        # ---- AR: next-token prediction on codebook 0 ----
        sep_code, target = self._pad_eos(sep_codes[..., 0], self.args.num_tokens)

        mix_embed = self._ar_mix_embed(mix_codes)
        sep_embed = self.ar_embedding_layer(sep_code)

        mix_mag = self._log_magnitude(mix_wave)
        mix_feat = self.ar_spec_encoder(mix_mag)

        # Conditioning prefix = [spectrogram features, mixture code embeddings].
        mix_len = mix_feat.shape[1] + mix_embed.shape[1]
        sep_len = sep_embed.shape[1]
        mix_sep_embed = torch.cat([mix_feat, mix_embed, sep_embed], dim=1)
        mix_sep_embed = self.ar_prenet(mix_sep_embed)
        mix_sep_embed = self.ar_position(mix_sep_embed)

        mix_sep_attn_mask = self._ar_attn_mask(mix_len, sep_len, device)
        mix_sep_dec, _ = self.ar_decoder((mix_sep_embed, None), mask=mix_sep_attn_mask)
        logits = self.ar_pred_layer(mix_sep_dec[:, mix_len:])
        logits = rearrange(logits, "b t h -> b h t")
        ar_loss = F.cross_entropy(logits, target, reduction="mean")
        ar_accuracy_metric = self.ar_acc_metric(logits.detach(), target)
        ar_accuracy_metric = ar_accuracy_metric.detach().cpu().numpy().item()

        if self.args.stage == 1:
            return ar_loss, ar_loss, 0.0, ar_accuracy_metric, 0.0

        # ---- NAR: predict one randomly chosen codebook 1..num_cb-1 ----
        num_nar_layers = self.args.num_cb - 1
        nar_stage = self.rng.choices(
            list(range(1, self.args.num_cb)),
            weights=[1.0 / num_nar_layers] * num_nar_layers,
            k=1,
        )[0]

        # Mixture: all codebooks; target speech: codebooks below nar_stage.
        mix_embed = self._stack_embeddings(mix_codes, self.args.num_cb - 1)
        sep_embed = self._stack_embeddings(sep_codes, nar_stage - 1)
        mix_feat = self.nar_spec_encoder(mix_mag)

        mix_sep_embed = torch.cat([mix_feat, mix_embed, sep_embed], dim=1)
        mix_sep_embed = self.nar_prenet(mix_sep_embed)
        mix_sep_embed = self.nar_position(mix_sep_embed)

        target = sep_codes[..., nar_stage]

        mix_sep_dec, _ = self.nar_decoder(
            (mix_sep_embed, self.nar_stage_embeddings[nar_stage - 1].weight),
            src_key_padding_mask=None,
        )
        mix_sep_dec = mix_sep_dec[:, mix_len:]
        logits = self.nar_pred_layers[nar_stage - 1](mix_sep_dec).permute(0, 2, 1)

        nar_loss = F.cross_entropy(logits, target, ignore_index=self.args.num_tokens, reduction="mean")

        # The metric expects num_tokens + 1 classes: pad one extra class
        # (the EOS slot) with the minimum logit so it is never top-ranked.
        nar_acc_metric = self.nar_acc_metric(
            F.pad(
                logits.detach(),
                (0, 0, 0, 1, 0, 0),
                value=logits.min().cpu().item(),
            ),
            target,
        ).item()

        total_loss = ar_loss + nar_loss
        return total_loss, ar_loss, nar_loss, ar_accuracy_metric, nar_acc_metric

    # -------------------------------------------------------------- inference

    @torch.no_grad()
    def generate(self, mix_wave, sep_wav, top_k=-100, temperature=1.0):
        """Generate separated waveforms from a mixture waveform.

        The first codec frame of speaker 1 (taken from ``sep_wav``) is used as a
        one-frame prompt. AR sampling continues until EOS is sampled, EOS is the
        arg-max, or the sequence reaches twice the mixture's code length (two
        speakers concatenated). The NAR decoder then fills the remaining
        codebooks greedily. Call ``model.eval()`` first to disable dropout.

        Args:
            mix_wave: mixture waveform, shape of [1, T]
            sep_wav: reference waveforms, shape of [1, 2, T] (only speaker 1's
                first frame is used)
            top_k: forwarded to ``top_k_sampling`` (presumably <= 0 disables
                top-k filtering; confirm).
            temperature: sampling temperature.

        Returns:
            Waveform decoded by Vocos from the generated codes.
        """
        batch_size, _ = mix_wave.shape
        device = mix_wave.device
        assert batch_size == 1, f"Only support batch size 1, but got {batch_size}."

        mix_codes = self._encodec_encode(mix_wave)
        # Prompt = first frame, codebook 0, of speaker 1 -> shape [1, 1].
        sep_code = self._encodec_encode(sep_wav[:, 0])[:, 0:1, 0]
        mix_code_len = mix_codes.shape[1]
        mix_mag = self._log_magnitude(mix_wave)

        # ---- AR: same [spec features, mixture codes, target] layout as forward ----
        ar_prefix = torch.cat([self.ar_spec_encoder(mix_mag), self._ar_mix_embed(mix_codes)], dim=1)
        ar_prefix_len = ar_prefix.shape[1]

        while True:
            sep_embed = self.ar_embedding_layer(sep_code)
            mix_sep_embed = torch.cat([ar_prefix, sep_embed], dim=1)
            mix_sep_embed = self.ar_prenet(mix_sep_embed)
            mix_sep_embed = self.ar_position(mix_sep_embed)

            mix_sep_attn_mask = self._ar_attn_mask(ar_prefix_len, sep_code.shape[1], device)
            mix_sep_dec, _ = self.ar_decoder((mix_sep_embed, None), mask=mix_sep_attn_mask)
            logits = self.ar_pred_layer(mix_sep_dec[:, -1])

            samples = top_k_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)

            if (
                (torch.argmax(logits, dim=-1)[0] == self.args.num_tokens)
                or (samples[0][0] == self.args.num_tokens)
                or (sep_code.shape[1] >= 2 * mix_code_len)
            ):
                print(f"EOS token reached at {sep_code.shape[1]}")
                break

            sep_code = torch.cat([sep_code, samples], dim=-1)

        # ---- NAR: greedily predict codebooks 1..num_cb-1 ----
        codes = [sep_code]
        nar_prefix = torch.cat(
            [self.nar_spec_encoder(mix_mag), self._stack_embeddings(mix_codes, self.args.num_cb - 1)], dim=1
        )
        nar_prefix_len = nar_prefix.shape[1]
        sep_embed = self.nar_embedding_layers[0](sep_code)

        for i, (pred_layer, embed_layer) in enumerate(zip(self.nar_pred_layers, self.nar_embedding_layers[1:])):
            mix_sep_embed = torch.cat([nar_prefix, sep_embed], dim=1)
            mix_sep_embed = self.nar_prenet(mix_sep_embed)
            mix_sep_embed = self.nar_position(mix_sep_embed)

            mix_sep_dec, _ = self.nar_decoder(
                (mix_sep_embed, self.nar_stage_embeddings[i].weight),
                src_key_padding_mask=None,
            )
            logits = pred_layer(mix_sep_dec[:, nar_prefix_len:])
            samples = torch.argmax(logits, dim=-1)
            codes.append(samples)

            # Feed the new codebook into the next stage's input (not needed
            # after the last stage).
            if i < self.args.num_cb - 2:
                sep_embed = sep_embed + embed_layer(samples)

        assert len(codes) == self.args.num_cb
        codes = torch.stack(codes, dim=-1)  # [1, T, num_cb]

        return self._vocos_decode(codes)
|
|
|
|
if __name__ == "__main__":
    model = Model()

    # Training-style forward pass on random audio: 2 mixtures, 2 speakers each.
    mix = torch.rand(2, 22000)
    sep = torch.rand(2, 2, 22000)
    output = model(mix, sep)
    print(output)

    # Inference supports batch size 1 only; switch to eval so dropout is off.
    model.eval()
    mix = torch.rand(1, 22000)
    sep = torch.rand(1, 2, 22000)
    audio_values = model.generate(mix, sep)
    print(audio_values.shape)
|
|
| |
|
|