# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torchaudio
import numpy as np
import time

from .valle_ar_trainer import ValleARTrainer, make_pad_mask


class ValleNARTrainer(ValleARTrainer):
    def __init__(self, args=None, cfg=None):
        super().__init__(args, cfg)
        print("simple NAR")
        self.top1_accuracies = {layer: [] for layer in range(1, 8)}
        self.top5_accuracies = {layer: [] for layer in range(1, 8)}
        self.top10_accuracies = {layer: [] for layer in range(1, 8)}

    def _build_model(self):
        from .valle_nar import ValleNAR

        return ValleNAR(**self.cfg.model)

    def _train_step(self, batch):
        """Run one NAR training step and return the loss.

        batch is a dict('speech', 'speech_len', 'phone_ids', 'phone_lens') with
            speech: [B, T]
            speech_len: [B]
            phone_ids: [B, T]
            phone_lens: [B]
        """
        device = self.accelerator.device
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.to(device)
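
        # Tokenize the waveforms into discrete RVQ codes with the codec encoder;
        # gradients are not needed for this step.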
        with torch.no_grad():
            if self.cfg.use_speechtokenizer:
                # Extract discrete codes from SpeechTokenizer (16 kHz audio).
                vq_id = self.codec_encoder.encode(
                    batch["speech"].unsqueeze(1)
                )  # [B, 1, T] -> (n_q, B, T)
                # RVQ_1 = codes[:1, :, :]  # contains content info; can be considered semantic tokens
                # RVQ_supplement = codes[1:, :, :]  # contains timbre info missed by the first quantizer
                # Concatenate semantic tokens (RVQ_1) and supplementary timbre tokens, then decode:
                # wav = self.codec_encoder.decode(vq_id)
                # torchaudio.save('a.wav', wav[0].cpu(), 16000)
                # Decode from RVQ tokens of the i-th through j-th quantizers:
                # wav = model.decode(codes[i : (j + 1)], st=i)
            else:
                # Using EnCodec (24 kHz audio).
                vq_id = self.codec_encoder.encode(batch["speech"].unsqueeze(1))
                vq_id = torch.cat([encoded[0] for encoded in vq_id], dim=-1).transpose(
                    0, 1
                )
                # recovered_audio = self.codec_decoder(vq_emb, vq=False)
                # torchaudio.save('a.wav', recovered_audio[0], 16000)

        # vq_id: [n_q, B, T // 320]
        batch["speech"] = vq_id
        batch["speech_len"] = batch["speech_len"] // 320  # our codec downsamples 320x
        assert batch["speech_len"].max() <= batch["speech"].shape[-1]
        phone_mask = 1 - make_pad_mask(
            batch["phone_lens"], max_len=batch["phone_ids"].size(1), left_pad=False
        ).to(torch.long)
        speech_mask = 1 - make_pad_mask(
            batch["speech_len"], max_len=batch["speech"].size(-1)
        ).to(torch.long)
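
        # Re-seed numpy per process so that random choices inside the model
        # (e.g. which quantization layer to train on) differ across workers.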
        np.random.seed(int(time.time()) - 5 * self.accelerator.process_index)

        if hasattr(self.cfg.train, "dropout"):
            dropout = self.cfg.train.dropout
        else:
            dropout = 0.0
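
        # The NAR forward pass internally selects one target quantization layer;
        # `out.target_quantization_layer` reports which layer was trained on.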
        out = self.model(
            phone_ids=batch["phone_ids"],
            phone_mask=phone_mask,
            target_ids=batch["speech"],
            target_mask=speech_mask,
            dropout=dropout,
        )
        loss = out.loss

        self.accelerator.log(
            {f"Train/NAR L{out.target_quantization_layer} Top1 acc": out.top1_acc},
            step=self.step,
        )
        self.accelerator.log(
            {f"Train/NAR L{out.target_quantization_layer} Top5 acc": out.top5_acc},
            step=self.step,
        )
        self.accelerator.log(
            {f"Train/NAR L{out.target_quantization_layer} Top10 acc": out.top10_acc},
            step=self.step,
        )
        # if hasattr(out, 'top1_acc'):
        #     idx = out.target_quantization_layer
        #     self.top1_accuracies[idx].append(out.top1_acc)
        #     self.top5_accuracies[idx].append(out.top5_acc)
        #     self.top10_accuracies[idx].append(out.top10_acc)
        #     if len(self.top1_accuracies[idx]) >= 160:
        #         breakpoint()
        # if self.accelerator.is_main_process:
        #     print(loss)

        return loss

    def _test_step(self, batch):
        """Run codec reconstruction and NAR inference on one batch.

        batch is a dict('speech', 'speech_len', 'phone_ids', 'phone_lens') with
            speech: [B, T]
            speech_len: [B]
            phone_ids: [B, T]
            phone_lens: [B]
        """
        device = self.accelerator.device
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.to(device)
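
        # Tokenize the ground-truth waveform; the codes are used both to render
        # a reference reconstruction and to condition the sampling step below.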
        with torch.no_grad():
            if self.cfg.use_speechtokenizer:
                # Extract discrete codes from SpeechTokenizer (16 kHz audio).
                vq_id = self.codec_encoder.encode(
                    batch["speech"].unsqueeze(1)
                )  # [B, 1, T] -> (n_q, B, T)
                # Concatenate semantic tokens (RVQ_1) and supplementary timbre tokens, then decode:
                # wav = self.codec_encoder.decode(vq_id)
                # torchaudio.save('a.wav', wav[0].cpu(), 16000)
            else:
                # Using EnCodec (24 kHz audio).
                vq_id = self.codec_encoder.encode(batch["speech"].unsqueeze(1))
                vq_id = torch.cat([encoded[0] for encoded in vq_id], dim=-1).transpose(
                    0, 1
                )
                # recovered_audio = self.codec_encoder.decode([(vq_id.transpose(0, 1), None)])
                # recovered_audio = self.codec_decoder(vq_emb, vq=False)
                # torchaudio.save('a.wav', recovered_audio[0], 16000)

            # vq_id: [n_q, B, T // 320]
            # vq_emb = self.codec_decoder.quantizer.vq2emb(vq=vq_id[:1], n_quantizers=1)
            # recovered_audio = self.codec_decoder(vq_emb, vq=False)
            # recovered_audio.shape: torch.Size([1, 1, 50200])
batch["speech"] = vq_id | |
# save gt | |
if self.cfg.use_speechtokenizer: | |
recovered_audio = self.codec_encoder.decode(vq_id) | |
else: | |
recovered_audio = self.codec_encoder.decode( | |
[(vq_id.transpose(0, 1), None)] | |
) | |
torchaudio.save("gt.wav", recovered_audio[0].cpu(), 16000) | |
        self.model.eval()
        out_vq_ids = self.model.sample_hf(
            phone_ids=batch["phone_ids"][:1],
            prompt_ids=batch["speech"][:, :1, :150],
            first_stage_ids=batch["speech"][0, :1, 150:],
        )
        # breakpoint()
        # out_vq_ids = torch.cat([batch['speech'][:, :225], out_vq_ids], dim=1)

        # Reconstruct audio from the generated tokens.
        if self.cfg.use_speechtokenizer:
            recovered_audio = self.codec_encoder.decode(out_vq_ids)
        else:
            recovered_audio = self.codec_encoder.decode(
                [(out_vq_ids.transpose(0, 1)[:1], None)]
            )
        torchaudio.save("a.wav", recovered_audio[0].cpu(), 16000)
        # breakpoint()
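

# A minimal usage sketch (an assumption, not part of this file): in Amphion the
# trainer is normally constructed by the recipe's entry script with parsed args
# and a config whose `cfg.model`, `cfg.train`, and `cfg.use_speechtokenizer`
# fields are populated, roughly:
#
#     trainer = ValleNARTrainer(args=args, cfg=cfg)
#     trainer.train_loop()  # entry-point name assumed from the base trainer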