yhzx233's picture
feat: app.py
ea174b0
# -*- coding: utf-8 -*-
import yaml
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from .nn.feature_extractor import MelFeatureExtractor
from .nn.modules import OmniAudioEncoder, OmniAudioDecoder, ResidualDownConv, UpConv, Transformer, Vocos
from .nn.quantizer import ResidualVQ
class XY_Tokenizer(nn.Module):
def __init__(self, generator_params):
super().__init__()
# Basic parameters
self.input_sample_rate = generator_params['input_sample_rate']
self.output_sample_rate = generator_params['output_sample_rate']
self.encoder_downsample_rate = 1280
self.decoder_upsample_rate = 1920
self.code_dim = generator_params['quantizer_kwargs']['input_dim']
## Codec part
## Semantic channel
self.semantic_encoder = OmniAudioEncoder(**generator_params['semantic_encoder_kwargs'])
self.semantic_encoder_adapter = Transformer(**generator_params['semantic_encoder_adapter_kwargs'])
## Acoustic channel
self.acoustic_encoder = OmniAudioEncoder(**generator_params['acoustic_encoder_kwargs'])
## Semantic & acoustic shared parameters
self.pre_rvq_adapter = Transformer(**generator_params['pre_rvq_adapter_kwargs'])
self.downsample = ResidualDownConv(**generator_params['downsample_kwargs'])
self.quantizer = ResidualVQ(**generator_params['quantizer_kwargs'])
self.nq = generator_params['quantizer_kwargs']['num_quantizers']
self.post_rvq_adapter = Transformer(**generator_params['post_rvq_adapter_kwargs'])
## Acoustic channel
self.upsample = UpConv(**generator_params['upsample_kwargs'])
self.acoustic_decoder = OmniAudioDecoder(**generator_params['acoustic_decoder_kwargs'])
self.enhanced_vocos = Vocos(**generator_params['vocos_kwargs'])
## Feature extractor
self.feature_extractor = MelFeatureExtractor(**generator_params['feature_extractor_kwargs'])
@torch.inference_mode()
def inference_tokenize(self, x, input_lengths):
"""
Input:
x: Waveform tensor # (B, 1, T), T <= 30s * sample_rate
input_lengths: Valid length for each sample # (B,)
Output:
dict: Contains the following key-value pairs
"zq": Quantized embeddings # (B, D, T)
"codes": Quantization codes # (nq, B, T)
"codes_lengths": Quantization code lengths # (B,)
"""
list_x = [xi[:, :x_len].reshape(-1).cpu().numpy() for xi, x_len in zip(x, input_lengths)]
features = self.feature_extractor(
list_x,
sampling_rate=self.input_sample_rate,
return_tensors="pt",
return_attention_mask=True
)
input_mel = features['input_features'].to(x.device).to(x.dtype) # (B, D, 3000)
audio_attention_mask = features['attention_mask'].to(x.device) # (B, 3000)
# Get batch size and sequence length of the input
mel_output_length = torch.sum(audio_attention_mask, dim=-1).long() # (B,)
# Semantic channel
semantic_encoder_output, semantic_encoder_output_length = self.semantic_encoder(input_mel, mel_output_length) # (B, D, T), 100hz -> 50hz
semantic_encoder_adapter_output, semantic_encoder_adapter_output_length = self.semantic_encoder_adapter(semantic_encoder_output, semantic_encoder_output_length) # (B, D, T), 50hz
# Acoustic channel
acoustic_encoder_output, acoustic_encoder_output_length = self.acoustic_encoder(input_mel, mel_output_length) # (B, D, T), 100hz -> 50hz
# Semantic & acoustic mixing
concated_semantic_acoustic_channel = torch.concat([semantic_encoder_adapter_output, acoustic_encoder_output], dim=1) # (B, D, T)
concated_semantic_acoustic_channel_length = acoustic_encoder_output_length
pre_rvq_adapter_output, pre_rvq_adapter_output_length = self.pre_rvq_adapter(concated_semantic_acoustic_channel, concated_semantic_acoustic_channel_length) # (B, D, T), 50hz
downsample_output, downsample_output_length = self.downsample(pre_rvq_adapter_output, pre_rvq_adapter_output_length) # (B, D, T), 50hz -> 12.5hz
zq, codes, vq_loss, _, quantizer_output_length = self.quantizer(downsample_output, downsample_output_length) # (B, D, T), (nq, B, T), (nq,), (nq, B, D, T), (B,)
return {
"zq": zq, # (B, D, T)
"codes": codes, # (nq, B, T)
"codes_lengths": quantizer_output_length # (B,)
}
@torch.inference_mode()
def inference_detokenize(self, codes, codes_lengths):
"""
Input:
codes: Quantization codes # (nq, B, T)
codes_lengths: Quantization code lengths for each sample # (B,)
Output:
dict: Contains the following key-value pairs
"y": Synthesized audio waveform # (B, 1, T)
"output_length": Output lengths # (B,)
"""
zq = self.quantizer.decode_codes(codes) # (B, D, T)
post_rvq_adapter_output, post_rvq_adapter_output_length = self.post_rvq_adapter(zq, codes_lengths) # (B, D, T), 12.5hz
# Acoustic channel
upsample_output, upsample_output_length = self.upsample(post_rvq_adapter_output, post_rvq_adapter_output_length) # (B, D, T), 12.5hz -> 50hz
acoustic_decoder_output, acoustic_decoder_output_length = self.acoustic_decoder(upsample_output, upsample_output_length) # (B, D, T), 50hz -> 100hz
y, vocos_output_length = self.enhanced_vocos(acoustic_decoder_output, acoustic_decoder_output_length) # (B, 1, T), 100hz -> 16khz
return {
"y": y, # (B, 1, T)
"output_length": vocos_output_length, # (B,)
}
@torch.inference_mode()
def encode(self, wav_list, overlap_seconds=10, device=torch.device("cuda")):
"""
Input:
wav_list: List of audio waveforms, each with potentially different length, may exceed 30 seconds # B * (T,)
overlap_seconds: Overlap in seconds, process 30 seconds at a time, keeping (30 - overlap_seconds) seconds of valid output
Output:
dict: Contains the following key-value pairs
"codes_list": List of quantization codes # B * (nq, T)
"""
duration_seconds = 30 - overlap_seconds
chunk_size = int(30 * self.input_sample_rate) # Maximum samples per chunk
duration_size = int(duration_seconds * self.input_sample_rate) # Valid output samples per chunk
code_duration_length = duration_size // self.encoder_downsample_rate # Valid code length per chunk
# Get maximum waveform length
max_length = max(len(wav) for wav in wav_list)
batch_size = len(wav_list)
wav_tensor = torch.zeros(batch_size, 1, max_length, device=device)
input_lengths = torch.zeros(batch_size, dtype=torch.long, device=device)
for i, wav in enumerate(wav_list):
wav_tensor[i, 0, :len(wav)] = wav
input_lengths[i] = len(wav) # (B,)
# Calculate number of chunks needed
max_chunks = (max_length + duration_size - 1) // duration_size
codes_list = []
# Process the entire batch in chunks
for chunk_idx in range(max_chunks):
start = chunk_idx * duration_size
end = min(start + chunk_size, max_length)
chunk = wav_tensor[:, :, start:end] # (B, 1, T')
chunk_lengths = torch.clamp(input_lengths - start, 0, end - start) # (B,)
# Skip empty chunks
if chunk_lengths.max() == 0:
continue
# Encode
result = self.inference_tokenize(chunk, chunk_lengths) # {"zq": (B, D, T'), "codes": (nq, B, T'), "codes_lengths": (B,)}
chunk_codes = result["codes"] # (nq, B, T')
chunk_code_lengths = result["codes_lengths"] # (B,)
# Extract valid portion
valid_code_lengths = torch.clamp(chunk_code_lengths, 0, code_duration_length) # (B,)
valid_chunk_codes = torch.zeros(self.nq, batch_size, code_duration_length, device=device, dtype=chunk_codes.dtype)
for b in range(batch_size):
if valid_code_lengths[b] > 0:
valid_chunk_codes[:, b, :valid_code_lengths[b]] = chunk_codes[:, b, :valid_code_lengths[b]] # (nq, B, valid_code_length)
codes_list.append(valid_chunk_codes) # (nq, B, valid_code_length)
# Concatenate all chunks
if codes_list:
codes_tensor = torch.cat(codes_list, dim=-1) # (nq, B, T_total)
codes_list = [codes_tensor[:, i, :input_lengths[i] // self.encoder_downsample_rate] for i in range(batch_size)] # B * (nq, T)
else:
codes_list = [torch.zeros(self.nq, 0, device=device, dtype=torch.long) for _ in range(batch_size)] # B * (nq, 0)
return {
"codes_list": codes_list # B * (nq, T)
}
@torch.inference_mode()
def decode(self, codes_list, overlap_seconds=10, device=torch.device("cuda")):
"""
Input:
codes_list: List of quantization codes # B * (nq, T)
overlap_seconds: Overlap in seconds, process 30 seconds at a time, keeping (30 - overlap_seconds) seconds of valid output
Output:
dict: Contains the following key-value pairs
"syn_wav_list": List of synthesized audio waveforms # B * (T,)
"""
duration_seconds = 30 - overlap_seconds
chunk_code_length = int(30 * self.input_sample_rate // self.encoder_downsample_rate) # Maximum code length per chunk
duration_code_length = int(duration_seconds * self.input_sample_rate // self.encoder_downsample_rate) # Valid code length per chunk
duration_wav_length = duration_code_length * self.decoder_upsample_rate # Valid waveform length per chunk
# Get maximum code length
max_code_length = max(codes.shape[-1] for codes in codes_list)
batch_size = len(codes_list)
codes_tensor = torch.zeros(self.nq, batch_size, max_code_length, device=device, dtype=torch.long)
code_lengths = torch.zeros(batch_size, dtype=torch.long, device=device)
for i, codes in enumerate(codes_list):
codes_tensor[:, i, :codes.shape[-1]] = codes.to(device)
code_lengths[i] = codes.shape[-1] # (B,)
# Calculate number of chunks needed
max_chunks = (max_code_length + duration_code_length - 1) // duration_code_length
wav_list = []
# Process the entire batch in chunks
for chunk_idx in range(max_chunks):
start = chunk_idx * duration_code_length
end = min(start + chunk_code_length, max_code_length)
chunk_codes = codes_tensor[:, :, start:end] # (nq, B, T')
chunk_code_lengths = torch.clamp(code_lengths - start, 0, end - start) # (B,)
# Skip empty chunks
if chunk_code_lengths.max() == 0:
continue
# Decode
result = self.inference_detokenize(chunk_codes, chunk_code_lengths) # {"y": (B, 1, T'), "output_length": (B,)}
chunk_wav = result["y"] # (B, 1, T')
chunk_wav_lengths = result["output_length"] # (B,)
# Extract valid portion
valid_wav_lengths = torch.clamp(chunk_wav_lengths, 0, duration_wav_length) # (B,)
valid_chunk_wav = torch.zeros(batch_size, 1, duration_wav_length, device=device)
for b in range(batch_size):
if valid_wav_lengths[b] > 0:
valid_chunk_wav[b, :, :valid_wav_lengths[b]] = chunk_wav[b, :, :valid_wav_lengths[b]] # (B, 1, valid_wav_length)
wav_list.append(valid_chunk_wav) # (B, 1, valid_wav_length)
# Concatenate all chunks
if wav_list:
wav_tensor = torch.cat(wav_list, dim=-1) # (B, 1, T_total)
syn_wav_list = [wav_tensor[i, 0, :code_lengths[i] * self.decoder_upsample_rate] for i in range(batch_size)] # B * (T,)
else:
syn_wav_list = [torch.zeros(0, device=device) for _ in range(batch_size)] # B * (0,)
return {
"syn_wav_list": syn_wav_list # B * (T,)
}
@classmethod
def load_from_checkpoint(cls, config_path: str, ckpt_path: str):
# Load model from configuration file and checkpoint
logging.info(f"Loading model from {config_path} and {ckpt_path}")
# Load configuration
with open(config_path, 'r') as f:
config = yaml.safe_load(f)
# Create model instance
model = cls(config['generator_params'])
# Load checkpoint
checkpoint = torch.load(ckpt_path, map_location='cpu')
# Check if checkpoint contains 'generator' key
if 'generator' in checkpoint:
model.load_state_dict(checkpoint['generator'])
else:
model.load_state_dict(checkpoint)
return model