import yaml
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F

from .nn.feature_extractor import MelFeatureExtractor
from .nn.modules import OmniAudioEncoder, OmniAudioDecoder, ResidualDownConv, UpConv, Transformer, Vocos
from .nn.quantizer import ResidualVQ

class XY_Tokenizer(nn.Module):
    def __init__(self, generator_params):
        super().__init__()

        # Sample rates and fixed resampling factors of the encoder/decoder stacks
        self.input_sample_rate = generator_params['input_sample_rate']
        self.output_sample_rate = generator_params['output_sample_rate']
        self.encoder_downsample_rate = 1280   # input samples per quantization code
        self.decoder_upsample_rate = 1920     # output samples per quantization code
        self.code_dim = generator_params['quantizer_kwargs']['input_dim']

        # Encoder: a semantic branch (encoder + adapter) and an acoustic branch
        self.semantic_encoder = OmniAudioEncoder(**generator_params['semantic_encoder_kwargs'])
        self.semantic_encoder_adapter = Transformer(**generator_params['semantic_encoder_adapter_kwargs'])
        self.acoustic_encoder = OmniAudioEncoder(**generator_params['acoustic_encoder_kwargs'])

        # Quantization: adapter, temporal downsampling, and residual vector quantizer
        self.pre_rvq_adapter = Transformer(**generator_params['pre_rvq_adapter_kwargs'])
        self.downsample = ResidualDownConv(**generator_params['downsample_kwargs'])
        self.quantizer = ResidualVQ(**generator_params['quantizer_kwargs'])
        self.nq = generator_params['quantizer_kwargs']['num_quantizers']
        self.post_rvq_adapter = Transformer(**generator_params['post_rvq_adapter_kwargs'])

        # Decoder: temporal upsampling, acoustic decoder, and Vocos vocoder
        self.upsample = UpConv(**generator_params['upsample_kwargs'])
        self.acoustic_decoder = OmniAudioDecoder(**generator_params['acoustic_decoder_kwargs'])
        self.enhanced_vocos = Vocos(**generator_params['vocos_kwargs'])

        # Mel-spectrogram feature extractor applied to raw waveforms
        self.feature_extractor = MelFeatureExtractor(**generator_params['feature_extractor_kwargs'])
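
    # The constructor expects a `generator_params` dict, typically the
    # `generator_params` section of the YAML config consumed by
    # `load_from_checkpoint` below. A sketch of its top-level layout, inferred
    # from the lookups above (per-module kwargs are model-specific and omitted):
    #
    #   generator_params:
    #     input_sample_rate: ...
    #     output_sample_rate: ...
    #     feature_extractor_kwargs: {...}
    #     semantic_encoder_kwargs: {...}
    #     semantic_encoder_adapter_kwargs: {...}
    #     acoustic_encoder_kwargs: {...}
    #     pre_rvq_adapter_kwargs: {...}
    #     downsample_kwargs: {...}
    #     quantizer_kwargs: {input_dim: ..., num_quantizers: ..., ...}
    #     post_rvq_adapter_kwargs: {...}
    #     upsample_kwargs: {...}
    #     acoustic_decoder_kwargs: {...}
    #     vocos_kwargs: {...}
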
    @torch.inference_mode()
    def inference_tokenize(self, x, input_lengths):
        """
        Input:
            x: Waveform tensor                          # (B, 1, T), T <= 30s * sample_rate
            input_lengths: Valid length of each sample  # (B,)
        Output:
            dict with the following key-value pairs:
                "zq": Quantized embeddings              # (B, D, T')
                "codes": Quantization codes             # (nq, B, T')
                "codes_lengths": Code lengths           # (B,), where T' = T // encoder_downsample_rate
        """
        # Trim each waveform to its valid length and extract mel features
        list_x = [xi[:, :x_len].reshape(-1).cpu().numpy() for xi, x_len in zip(x, input_lengths)]
        features = self.feature_extractor(
            list_x,
            sampling_rate=self.input_sample_rate,
            return_tensors="pt",
            return_attention_mask=True
        )
        input_mel = features['input_features'].to(x.device).to(x.dtype)
        audio_attention_mask = features['attention_mask'].to(x.device)
        mel_output_length = torch.sum(audio_attention_mask, dim=-1).long()  # (B,)

        # Semantic branch: encoder followed by an adapter Transformer
        semantic_encoder_output, semantic_encoder_output_length = self.semantic_encoder(input_mel, mel_output_length)
        semantic_encoder_adapter_output, semantic_encoder_adapter_output_length = self.semantic_encoder_adapter(semantic_encoder_output, semantic_encoder_output_length)

        # Acoustic branch: encoder applied to the same mel features
        acoustic_encoder_output, acoustic_encoder_output_length = self.acoustic_encoder(input_mel, mel_output_length)

        # Concatenate the semantic and acoustic features along the channel dimension
        concated_semantic_acoustic_channel = torch.concat([semantic_encoder_adapter_output, acoustic_encoder_output], dim=1)
        concated_semantic_acoustic_channel_length = acoustic_encoder_output_length

        # Pre-RVQ adapter, temporal downsampling, and residual vector quantization
        pre_rvq_adapter_output, pre_rvq_adapter_output_length = self.pre_rvq_adapter(concated_semantic_acoustic_channel, concated_semantic_acoustic_channel_length)
        downsample_output, downsample_output_length = self.downsample(pre_rvq_adapter_output, pre_rvq_adapter_output_length)
        zq, codes, vq_loss, _, quantizer_output_length = self.quantizer(downsample_output, downsample_output_length)

        return {
            "zq": zq,                                   # (B, D, T')
            "codes": codes,                             # (nq, B, T')
            "codes_lengths": quantizer_output_length    # (B,)
        }

    @torch.inference_mode()
    def inference_detokenize(self, codes, codes_lengths):
        """
        Input:
            codes: Quantization codes                   # (nq, B, T)
            codes_lengths: Code lengths of each sample  # (B,)
        Output:
            dict with the following key-value pairs:
                "y": Synthesized audio waveform         # (B, 1, T * decoder_upsample_rate)
                "output_length": Output lengths         # (B,)
        """
        # Look up the quantized embeddings for the codes
        zq = self.quantizer.decode_codes(codes)  # (B, D, T)

        # Post-RVQ adapter and temporal upsampling back to the acoustic frame rate
        post_rvq_adapter_output, post_rvq_adapter_output_length = self.post_rvq_adapter(zq, codes_lengths)
        upsample_output, upsample_output_length = self.upsample(post_rvq_adapter_output, post_rvq_adapter_output_length)

        # Acoustic decoder followed by the Vocos vocoder to synthesize the waveform
        acoustic_decoder_output, acoustic_decoder_output_length = self.acoustic_decoder(upsample_output, upsample_output_length)
        y, vocos_output_length = self.enhanced_vocos(acoustic_decoder_output, acoustic_decoder_output_length)

        return {
            "y": y,                                     # (B, 1, T * decoder_upsample_rate)
            "output_length": vocos_output_length,       # (B,)
        }
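
    # Shape round trip between the two helpers above, for reference:
    #   inference_tokenize:   waveform (B, 1, T)  ->  codes (nq, B, T // 1280)
    #   inference_detokenize: codes (nq, B, T')   ->  waveform (B, 1, T' * 1920)
    # With an assumed input_sample_rate of 16 kHz (illustrative only; the real value
    # comes from the config) this is 16000 / 1280 = 12.5 codes per second, and
    # 12.5 * 1920 = 24000 output samples per second, i.e. a 24 kHz output_sample_rate.
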
    @torch.inference_mode()
    def encode(self, wav_list, overlap_seconds=10, device=torch.device("cuda")):
        """
        Input:
            wav_list: List of audio waveforms; lengths may differ and may exceed 30 seconds  # B * (T,)
            overlap_seconds: Chunk overlap in seconds; audio is processed 30 seconds at a time,
                keeping (30 - overlap_seconds) seconds of valid output per chunk
        Output:
            dict with the following key-value pairs:
                "codes_list": List of quantization codes    # B * (nq, T)
        """
        duration_seconds = 30 - overlap_seconds
        chunk_size = int(30 * self.input_sample_rate)                         # samples fed per chunk
        duration_size = int(duration_seconds * self.input_sample_rate)        # valid samples kept per chunk
        code_duration_length = duration_size // self.encoder_downsample_rate  # valid codes kept per chunk

        # Pad the batch to the longest waveform and record the valid lengths
        max_length = max(len(wav) for wav in wav_list)
        batch_size = len(wav_list)
        wav_tensor = torch.zeros(batch_size, 1, max_length, device=device)
        input_lengths = torch.zeros(batch_size, dtype=torch.long, device=device)
        for i, wav in enumerate(wav_list):
            wav_tensor[i, 0, :len(wav)] = wav.to(device)  # move to the target device in case the input lives on CPU
            input_lengths[i] = len(wav)

        # Slide over the audio in strides of duration_size, tokenizing up to 30 s at a time
        max_chunks = (max_length + duration_size - 1) // duration_size
        codes_list = []
        for chunk_idx in range(max_chunks):
            start = chunk_idx * duration_size
            end = min(start + chunk_size, max_length)
            chunk = wav_tensor[:, :, start:end]
            chunk_lengths = torch.clamp(input_lengths - start, 0, end - start)
            if chunk_lengths.max() == 0:
                continue

            result = self.inference_tokenize(chunk, chunk_lengths)
            chunk_codes = result["codes"]                 # (nq, B, T')
            chunk_code_lengths = result["codes_lengths"]  # (B,)

            # Keep only the non-overlapping part of each chunk's codes
            valid_code_lengths = torch.clamp(chunk_code_lengths, 0, code_duration_length)
            valid_chunk_codes = torch.zeros(self.nq, batch_size, code_duration_length, device=device, dtype=chunk_codes.dtype)
            for b in range(batch_size):
                if valid_code_lengths[b] > 0:
                    valid_chunk_codes[:, b, :valid_code_lengths[b]] = chunk_codes[:, b, :valid_code_lengths[b]]
            codes_list.append(valid_chunk_codes)

        # Concatenate the per-chunk codes and trim each sample to its true code length
        if codes_list:
            codes_tensor = torch.cat(codes_list, dim=-1)
            codes_list = [codes_tensor[:, i, :input_lengths[i] // self.encoder_downsample_rate] for i in range(batch_size)]
        else:
            codes_list = [torch.zeros(self.nq, 0, device=device, dtype=torch.long) for _ in range(batch_size)]

        return {
            "codes_list": codes_list    # B * (nq, T)
        }
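
    # Worked example for the chunking in `encode`, assuming input_sample_rate = 16000
    # and the default overlap_seconds = 10 (illustrative numbers only):
    #   chunk_size           = 30 * 16000        = 480000 samples fed per chunk
    #   duration_size        = (30 - 10) * 16000 = 320000 samples kept per chunk
    #   code_duration_length = 320000 // 1280    = 250 codes kept per chunk
    # `decode` below mirrors this in code space: 30 s worth of codes is detokenized
    # per chunk and only the first (30 - overlap_seconds) seconds of audio is kept.
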
    @torch.inference_mode()
    def decode(self, codes_list, overlap_seconds=10, device=torch.device("cuda")):
        """
        Input:
            codes_list: List of quantization codes      # B * (nq, T)
            overlap_seconds: Chunk overlap in seconds; codes are processed 30 seconds at a time,
                keeping (30 - overlap_seconds) seconds of valid output per chunk
        Output:
            dict with the following key-value pairs:
                "syn_wav_list": List of synthesized audio waveforms    # B * (T,)
        """
        duration_seconds = 30 - overlap_seconds
        chunk_code_length = int(30 * self.input_sample_rate // self.encoder_downsample_rate)                    # codes fed per chunk
        duration_code_length = int(duration_seconds * self.input_sample_rate // self.encoder_downsample_rate)   # valid codes kept per chunk
        duration_wav_length = duration_code_length * self.decoder_upsample_rate                                 # valid output samples kept per chunk

        # Pad the batch to the longest code sequence and record the valid lengths
        max_code_length = max(codes.shape[-1] for codes in codes_list)
        batch_size = len(codes_list)
        codes_tensor = torch.zeros(self.nq, batch_size, max_code_length, device=device, dtype=torch.long)
        code_lengths = torch.zeros(batch_size, dtype=torch.long, device=device)
        for i, codes in enumerate(codes_list):
            codes_tensor[:, i, :codes.shape[-1]] = codes.to(device)
            code_lengths[i] = codes.shape[-1]

        # Slide over the codes in strides of duration_code_length, detokenizing up to 30 s at a time
        max_chunks = (max_code_length + duration_code_length - 1) // duration_code_length
        wav_list = []
        for chunk_idx in range(max_chunks):
            start = chunk_idx * duration_code_length
            end = min(start + chunk_code_length, max_code_length)
            chunk_codes = codes_tensor[:, :, start:end]
            chunk_code_lengths = torch.clamp(code_lengths - start, 0, end - start)
            if chunk_code_lengths.max() == 0:
                continue

            result = self.inference_detokenize(chunk_codes, chunk_code_lengths)
            chunk_wav = result["y"]                      # (B, 1, T)
            chunk_wav_lengths = result["output_length"]  # (B,)

            # Keep only the non-overlapping part of each chunk's waveform
            valid_wav_lengths = torch.clamp(chunk_wav_lengths, 0, duration_wav_length)
            valid_chunk_wav = torch.zeros(batch_size, 1, duration_wav_length, device=device)
            for b in range(batch_size):
                if valid_wav_lengths[b] > 0:
                    valid_chunk_wav[b, :, :valid_wav_lengths[b]] = chunk_wav[b, :, :valid_wav_lengths[b]]
            wav_list.append(valid_chunk_wav)

        # Concatenate the per-chunk audio and trim each sample to its true output length
        if wav_list:
            wav_tensor = torch.cat(wav_list, dim=-1)
            syn_wav_list = [wav_tensor[i, 0, :code_lengths[i] * self.decoder_upsample_rate] for i in range(batch_size)]
        else:
            syn_wav_list = [torch.zeros(0, device=device) for _ in range(batch_size)]

        return {
            "syn_wav_list": syn_wav_list    # B * (T,)
        }

    @classmethod
    def load_from_checkpoint(cls, config_path: str, ckpt_path: str):
        logging.info(f"Loading model from {config_path} and {ckpt_path}")

        # Build the model from the YAML config
        with open(config_path, 'r') as f:
            config = yaml.safe_load(f)
        model = cls(config['generator_params'])

        # Load the weights; the checkpoint may store them under a 'generator' key
        checkpoint = torch.load(ckpt_path, map_location='cpu')
        if 'generator' in checkpoint:
            model.load_state_dict(checkpoint['generator'])
        else:
            model.load_state_dict(checkpoint)

        return model
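

# Example usage (a minimal sketch; the config/checkpoint/audio paths are
# placeholders and torchaudio is only one way to obtain a waveform tensor):
#
#   import torchaudio
#
#   model = XY_Tokenizer.load_from_checkpoint("config.yaml", "model.ckpt")
#   model = model.eval().to("cuda")
#
#   wav, sr = torchaudio.load("speech.wav")              # (channels, T) at sr Hz
#   assert sr == model.input_sample_rate                 # resample beforehand if needed
#
#   codes_list = model.encode([wav[0]])["codes_list"]            # [(nq, T_code)]
#   syn_wav_list = model.decode(codes_list)["syn_wav_list"]      # [(T_out,)] at output_sample_rate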