Spaces: mossttsd-space (Running on Zero)

mossttsd-space (#2)
- update (dbd498f5b6a8eb7aff9ea559070143b6f55c6315)
- update2 (8008b4069c3a9980a0137258cf9aab4f866c1d98)
Co-authored-by: zyq <rulerman@users.noreply.huggingface.co>

Files changed:
- XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml +114 -0
- XY_Tokenizer/xy_tokenizer/model.py +242 -150
- app.py +5 -4
- generation_utils.py +478 -165
XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml
ADDED
@@ -0,0 +1,114 @@
generator_params:
  input_sample_rate: 16000
  output_sample_rate: 32000
  encoder_downsample_rate: 1280
  decoder_upsample_rate: 2560

  feature_extractor_kwargs:
    chunk_length: 30
    feature_size: 80
    hop_length: 160
    n_fft: 400
    n_samples: 480000
    nb_max_frames: 3000
    padding_side: right
    padding_value: 0.0
    return_attention_mask: false
    sampling_rate: 16000

  # Codec / model architecture (inference required)
  semantic_encoder_kwargs: # 100hz -> 50hz
    num_mel_bins: 80
    sampling_rate: 16000
    hop_length: 160
    stride_size: 2
    kernel_size: 3
    d_model: 768
    scale_embedding: false
    max_audio_seconds: 30
    encoder_layers: 12
    encoder_attention_heads: 12
    encoder_ffn_dim: 3072
    activation_function: "gelu"

  semantic_encoder_adapter_kwargs: # 50hz
    input_dim: 768
    output_dim: 768
    d_model: 768
    max_source_positions: 1500
    encoder_layers: 4
    encoder_attention_heads: 12
    encoder_ffn_dim: 3072

  acoustic_encoder_kwargs: # 100hz -> 50hz
    num_mel_bins: 80
    sampling_rate: 16000
    hop_length: 160
    stride_size: 2
    kernel_size: 3
    d_model: 768
    scale_embedding: false
    max_audio_seconds: 30
    encoder_layers: 12
    encoder_attention_heads: 12
    encoder_ffn_dim: 3072
    activation_function: "gelu"

  pre_rvq_adapter_kwargs: # 50hz
    input_dim: 1536
    output_dim: 768
    d_model: 768
    max_source_positions: 1500
    encoder_layers: 4
    encoder_attention_heads: 12
    encoder_ffn_dim: 3072

  downsample_kwargs: # 50hz -> 12.5hz
    d_model: 768
    avg_pooler: 4

  quantizer_kwargs: # 12.5hz
    input_dim: 3072
    rvq_dim: 512
    output_dim: 3072
    num_quantizers: 8
    codebook_size: 1024
    codebook_dim: 512
    quantizer_dropout: 0.0
    commitment: 1

  post_rvq_adapter_kwargs: # 12.5hz
    input_dim: 3072
    output_dim: 3072
    d_model: 768
    max_source_positions: 375
    encoder_layers: 4
    encoder_attention_heads: 12
    encoder_ffn_dim: 3072

  upsample_kwargs: # 12.5hz -> 50hz
    d_model: 768
    stride: 4

  acoustic_decoder_kwargs: # 50hz -> 100hz
    num_mel_bins: 80
    sampling_rate: 16000
    hop_length: 160
    stride_size: 2
    kernel_size: 3
    d_model: 768
    scale_embedding: false
    max_audio_seconds: 30
    decoder_layers: 12
    decoder_attention_heads: 12
    decoder_ffn_dim: 3072
    activation_function: "gelu"

  vocos_kwargs: # 100hz -> 32khz
    input_channels: 80
    dim: 512
    intermediate_dim: 4096
    num_layers: 30
    n_fft: 1280
    hop_size: 320
    padding: "same"
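The rate fields above pin down the codec's frame rates: 16 kHz input divided by the encoder_downsample_rate of 1280 gives 12.5 Hz codes, and multiplying back by the decoder_upsample_rate of 2560 gives the 32 kHz output rate. A minimal sanity-check sketch (illustrative only, not part of this commit):

import yaml

# Derive the frame rates implied by the config values above.
with open("XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml") as f:
    gp = yaml.safe_load(f)["generator_params"]

code_rate_hz = gp["input_sample_rate"] / gp["encoder_downsample_rate"]  # 16000 / 1280 = 12.5
output_rate_hz = code_rate_hz * gp["decoder_upsample_rate"]             # 12.5 * 2560 = 32000

assert output_rate_hz == gp["output_sample_rate"]
print(code_rate_hz, output_rate_hz)  # 12.5 Hz codes, 32 kHz audio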
XY_Tokenizer/xy_tokenizer/model.py
CHANGED
@@ -1,146 +1,198 @@
# -*- coding: utf-8 -*-
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml

from .nn.feature_extractor import MelFeatureExtractor
from .nn.modules import (
    OmniAudioDecoder,
    OmniAudioEncoder,
    ResidualDownConv,
    Transformer,
    UpConv,
    Vocos,
)
from .nn.quantizer import ResidualVQ


class XY_Tokenizer(nn.Module):
    def __init__(self, generator_params):
        super().__init__()
        # Basic parameters
        self.input_sample_rate = generator_params["input_sample_rate"]
        self.output_sample_rate = generator_params["output_sample_rate"]

        self.encoder_downsample_rate = generator_params["encoder_downsample_rate"]
        self.decoder_upsample_rate = generator_params["decoder_upsample_rate"]
        self.code_dim = generator_params["quantizer_kwargs"]["input_dim"]

        ## Codec part

        ## Semantic channel
        self.semantic_encoder = OmniAudioEncoder(
            **generator_params["semantic_encoder_kwargs"]
        )

        self.semantic_encoder_adapter = Transformer(
            **generator_params["semantic_encoder_adapter_kwargs"]
        )

        ## Acoustic channel
        self.acoustic_encoder = OmniAudioEncoder(
            **generator_params["acoustic_encoder_kwargs"]
        )

        ## Semantic & acoustic shared parameters
        self.pre_rvq_adapter = Transformer(**generator_params["pre_rvq_adapter_kwargs"])

        self.downsample = ResidualDownConv(**generator_params["downsample_kwargs"])

        self.quantizer = ResidualVQ(**generator_params["quantizer_kwargs"])
        self.nq = generator_params["quantizer_kwargs"]["num_quantizers"]

        self.post_rvq_adapter = Transformer(
            **generator_params["post_rvq_adapter_kwargs"]
        )

        ## Acoustic channel
        self.upsample = UpConv(**generator_params["upsample_kwargs"])

        self.acoustic_decoder = OmniAudioDecoder(
            **generator_params["acoustic_decoder_kwargs"]
        )

        self.enhanced_vocos = Vocos(**generator_params["vocos_kwargs"])

        ## Feature extractor
        fe_kwargs = generator_params.get("feature_extractor_kwargs", {})
        self.feature_extractor = MelFeatureExtractor(**fe_kwargs)

    @torch.inference_mode()
    def inference_tokenize(self, x, input_lengths):
        """
        Input:
            x: Waveform tensor # (B, 1, T), T <= 30s * sample_rate
            input_lengths: Valid length for each sample # (B,)
        Output:
            dict: Contains the following key-value pairs
                "zq": Quantized embeddings # (B, D, T)
                "codes": Quantization codes # (nq, B, T)
                "codes_lengths": Quantization code lengths # (B,)
        """
        list_x = [
            xi[:, :x_len].reshape(-1).cpu().numpy()
            for xi, x_len in zip(x, input_lengths)
        ]
        features = self.feature_extractor(
            list_x,
            sampling_rate=self.input_sample_rate,
            return_tensors="pt",
            return_attention_mask=True,
        )
        input_mel = features["input_features"].to(x.device).to(x.dtype)  # (B, D, 3000)
        audio_attention_mask = features["attention_mask"].to(x.device)  # (B, 3000)

        # Get batch size and sequence length of the input
        mel_output_length = torch.sum(audio_attention_mask, dim=-1).long()  # (B,)

        # Semantic channel
        semantic_encoder_output, semantic_encoder_output_length = self.semantic_encoder(
            input_mel, mel_output_length
        )  # (B, D, T), 100hz -> 50hz

        semantic_encoder_adapter_output, semantic_encoder_adapter_output_length = (
            self.semantic_encoder_adapter(
                semantic_encoder_output, semantic_encoder_output_length
            )
        )  # (B, D, T), 50hz

        # Acoustic channel
        acoustic_encoder_output, acoustic_encoder_output_length = self.acoustic_encoder(
            input_mel, mel_output_length
        )  # (B, D, T), 100hz -> 50hz

        # Semantic & acoustic mixing
        concated_semantic_acoustic_channel = torch.concat(
            [semantic_encoder_adapter_output, acoustic_encoder_output], dim=1
        )  # (B, D, T)
        concated_semantic_acoustic_channel_length = acoustic_encoder_output_length

        pre_rvq_adapter_output, pre_rvq_adapter_output_length = self.pre_rvq_adapter(
            concated_semantic_acoustic_channel,
            concated_semantic_acoustic_channel_length,
        )  # (B, D, T), 50hz

        downsample_output, downsample_output_length = self.downsample(
            pre_rvq_adapter_output, pre_rvq_adapter_output_length
        )  # (B, D, T), 50hz -> 12.5hz

        zq, codes, vq_loss, _, quantizer_output_length = self.quantizer(
            downsample_output, downsample_output_length
        )  # (B, D, T), (nq, B, T), (nq,), (nq, B, D, T), (B,)

        return {
            "zq": zq,  # (B, D, T)
            "codes": codes,  # (nq, B, T)
            "codes_lengths": quantizer_output_length,  # (B,)
        }

    @torch.inference_mode()
    def inference_detokenize(self, codes, codes_lengths):
        """
        Input:
            codes: Quantization codes # (nq, B, T)
            codes_lengths: Quantization code lengths for each sample # (B,)
        Output:
            dict: Contains the following key-value pairs
                "y": Synthesized audio waveform # (B, 1, T)
                "output_length": Output lengths # (B,)
        """
        zq = self.quantizer.decode_codes(codes)  # (B, D, T)

        post_rvq_adapter_output, post_rvq_adapter_output_length = self.post_rvq_adapter(
            zq, codes_lengths
        )  # (B, D, T), 12.5hz

        # Acoustic channel
        upsample_output, upsample_output_length = self.upsample(
            post_rvq_adapter_output, post_rvq_adapter_output_length
        )  # (B, D, T), 12.5hz -> 50hz

        acoustic_decoder_output, acoustic_decoder_output_length = self.acoustic_decoder(
            upsample_output, upsample_output_length
        )  # (B, D, T), 50hz -> 100hz

        y, vocos_output_length = self.enhanced_vocos(
            acoustic_decoder_output, acoustic_decoder_output_length
        )  # (B, 1, T), 100hz -> 16khz

        return {
            "y": y,  # (B, 1, T)
            "output_length": vocos_output_length,  # (B,)
        }

    @torch.inference_mode()
    def encode(self, wav_list, overlap_seconds=10):
        """
        Input:
            wav_list: List of audio waveforms, each with potentially different length, may exceed 30 seconds # B * (T,)
            overlap_seconds: Overlap in seconds, process 30 seconds at a time, keeping (30 - overlap_seconds) seconds of valid output
        Output:
            dict: Contains the following key-value pairs
                "codes_list": List of quantization codes # B * (nq, T)
        """
        device = wav_list[0].device
        duration_seconds = 30 - overlap_seconds
        chunk_size = int(30 * self.input_sample_rate)  # Maximum samples per chunk
        duration_size = int(
            duration_seconds * self.input_sample_rate
        )  # Valid output samples per chunk
        code_duration_length = (
            duration_size // self.encoder_downsample_rate
        )  # Valid code length per chunk

        # Get maximum waveform length
        max_length = max(len(wav) for wav in wav_list)
@@ -148,8 +200,8 @@ class XY_Tokenizer(nn.Module):
        wav_tensor = torch.zeros(batch_size, 1, max_length, device=device)
        input_lengths = torch.zeros(batch_size, dtype=torch.long, device=device)
        for i, wav in enumerate(wav_list):
            wav_tensor[i, 0, : len(wav)] = wav
            input_lengths[i] = len(wav)  # (B,)

        # Calculate number of chunks needed
        max_chunks = (max_length + duration_size - 1) // duration_size
@@ -159,121 +211,161 @@
        for chunk_idx in range(max_chunks):
            start = chunk_idx * duration_size
            end = min(start + chunk_size, max_length)
            chunk = wav_tensor[:, :, start:end]  # (B, 1, T')
            chunk_lengths = torch.clamp(input_lengths - start, 0, end - start)  # (B,)

            # Skip empty chunks
            if chunk_lengths.max() == 0:
                continue

            # Encode
            result = self.inference_tokenize(
                chunk, chunk_lengths
            )  # {"zq": (B, D, T'), "codes": (nq, B, T'), "codes_lengths": (B,)}
            chunk_codes = result["codes"]  # (nq, B, T')
            chunk_code_lengths = result["codes_lengths"]  # (B,)

            # Extract valid portion
            valid_code_lengths = torch.clamp(
                chunk_code_lengths, 0, code_duration_length
            )  # (B,)
            valid_chunk_codes = torch.zeros(
                self.nq,
                batch_size,
                code_duration_length,
                device=device,
                dtype=chunk_codes.dtype,
            )
            for b in range(batch_size):
                if valid_code_lengths[b] > 0:
                    valid_chunk_codes[:, b, : valid_code_lengths[b]] = chunk_codes[
                        :, b, : valid_code_lengths[b]
                    ]  # (nq, B, valid_code_length)

            codes_list.append(valid_chunk_codes)  # (nq, B, valid_code_length)

        # Concatenate all chunks
        if codes_list:
            codes_tensor = torch.cat(codes_list, dim=-1)  # (nq, B, T_total)
            codes_list = [
                codes_tensor[:, i, : input_lengths[i] // self.encoder_downsample_rate]
                for i in range(batch_size)
            ]  # B * (nq, T)
        else:
            codes_list = [
                torch.zeros(self.nq, 0, device=device, dtype=torch.long)
                for _ in range(batch_size)
            ]  # B * (nq, 0)

        return {"codes_list": codes_list}  # B * (nq, T)

    @torch.inference_mode()
    def decode(self, codes_list, overlap_seconds=10):
        """
        Input:
            codes_list: List of quantization codes # B * (nq, T)
            overlap_seconds: Overlap in seconds, process 30 seconds at a time, keeping (30 - overlap_seconds) seconds of valid output
        Output:
            dict: Contains the following key-value pairs
                "syn_wav_list": List of synthesized audio waveforms # B * (T,)
        """
        device = codes_list[0].device
        duration_seconds = 30 - overlap_seconds
        chunk_code_length = int(
            30 * self.input_sample_rate // self.encoder_downsample_rate
        )  # Maximum code length per chunk
        duration_code_length = int(
            duration_seconds * self.input_sample_rate // self.encoder_downsample_rate
        )  # Valid code length per chunk
        duration_wav_length = (
            duration_code_length * self.decoder_upsample_rate
        )  # Valid waveform length per chunk

        # Get maximum code length
        max_code_length = max(codes.shape[-1] for codes in codes_list)
        batch_size = len(codes_list)
        codes_tensor = torch.zeros(
            self.nq, batch_size, max_code_length, device=device, dtype=torch.long
        )
        code_lengths = torch.zeros(batch_size, dtype=torch.long, device=device)
        for i, codes in enumerate(codes_list):
            codes_tensor[:, i, : codes.shape[-1]] = codes.to(device)
            code_lengths[i] = codes.shape[-1]  # (B,)

        # Calculate number of chunks needed
        max_chunks = (
            max_code_length + duration_code_length - 1
        ) // duration_code_length
        wav_list = []

        # Process the entire batch in chunks
        for chunk_idx in range(max_chunks):
            start = chunk_idx * duration_code_length
            end = min(start + chunk_code_length, max_code_length)
            chunk_codes = codes_tensor[:, :, start:end]  # (nq, B, T')
            chunk_code_lengths = torch.clamp(
                code_lengths - start, 0, end - start
            )  # (B,)

            # Skip empty chunks
            if chunk_code_lengths.max() == 0:
                continue

            # Decode
            result = self.inference_detokenize(
                chunk_codes, chunk_code_lengths
            )  # {"y": (B, 1, T'), "output_length": (B,)}
            chunk_wav = result["y"]  # (B, 1, T')
            chunk_wav_lengths = result["output_length"]  # (B,)

            # Extract valid portion
            valid_wav_lengths = torch.clamp(
                chunk_wav_lengths, 0, duration_wav_length
            )  # (B,)
            valid_chunk_wav = torch.zeros(
                batch_size, 1, duration_wav_length, device=device
            )
            for b in range(batch_size):
                if valid_wav_lengths[b] > 0:
                    valid_chunk_wav[b, :, : valid_wav_lengths[b]] = chunk_wav[
                        b, :, : valid_wav_lengths[b]
                    ]  # (B, 1, valid_wav_length)

            wav_list.append(valid_chunk_wav)  # (B, 1, valid_wav_length)

        # Concatenate all chunks
        if wav_list:
            wav_tensor = torch.cat(wav_list, dim=-1)  # (B, 1, T_total)
            syn_wav_list = [
                wav_tensor[i, 0, : code_lengths[i] * self.decoder_upsample_rate]
                for i in range(batch_size)
            ]  # B * (T,)
        else:
            syn_wav_list = [
                torch.zeros(0, device=device) for _ in range(batch_size)
            ]  # B * (0,)

        return {"syn_wav_list": syn_wav_list}  # B * (T,)

    @classmethod
    def load_from_checkpoint(cls, config_path: str, ckpt_path: str):
        # Load model from configuration file and checkpoint
        logging.info(f"Loading model from {config_path} and {ckpt_path}")

        # Load configuration
        with open(config_path, "r") as f:
            config = yaml.safe_load(f)

        # Create model instance
        model = cls(config["generator_params"])

        # Load checkpoint
        checkpoint = torch.load(ckpt_path, map_location="cpu")

        # Check if checkpoint contains 'generator' key
        if "generator" in checkpoint:
            model.load_state_dict(checkpoint["generator"])
        else:
            model.load_state_dict(checkpoint)

        return model
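A minimal usage sketch of the chunked API above (illustrative only; the checkpoint path is a placeholder taken from the commented-out path in app.py):

import torch
from XY_Tokenizer.xy_tokenizer.model import XY_Tokenizer

spt = XY_Tokenizer.load_from_checkpoint(
    config_path="XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml",
    ckpt_path="XY_Tokenizer/weights/xy_tokenizer.ckpt",  # placeholder checkpoint path
).eval()

# Two mono 16 kHz waveforms of different lengths; the 40 s one exercises the
# 30 s chunking with 10 s overlap. Random noise stands in for real audio.
wavs = [torch.randn(16000 * 40), torch.randn(16000 * 7)]

codes_list = spt.encode(wavs, overlap_seconds=10)["codes_list"]        # B * (nq, T), nq = 8
syn_wavs = spt.decode(codes_list, overlap_seconds=10)["syn_wav_list"]  # B * (T,)

print([c.shape for c in codes_list])  # about 12.5 codes per second of input
print([w.shape for w in syn_wavs])    # waveforms at output_sample_rate (32 kHz)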
app.py
CHANGED
@@ -131,15 +131,15 @@ LANGUAGES = {
# Model configuration
SYSTEM_PROMPT = "You are a speech synthesizer that generates natural, realistic, and human-like conversational audio from dialogue text."
MODEL_PATH = os.environ["MODEL_REPO_ID"]
SPT_CONFIG_PATH = "XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml"
# SPT_CHECKPOINT_PATH = "XY_Tokenizer/weights/xy_tokenizer.ckpt"
MAX_CHANNELS = 8

from huggingface_hub import hf_hub_download

SPT_CHECKPOINT_PATH = hf_hub_download(
    repo_id="OpenMOSS-Team/MOSS_TTSD_tokenizer",
    filename="MOSS_TTSD_tokenizer",
    cache_dir="XY_Tokenizer/weights"
)

@@ -245,7 +245,8 @@ def process_single_audio_generation(
        device=device,
        system_prompt=SYSTEM_PROMPT,
        start_idx=0,
        use_normalize=use_normalize,
        silence_duration=0.1,
    )

    # Check results
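These constants are what generation_utils.load_model consumes; a minimal wiring sketch (illustrative only; the MODEL_REPO_ID value is a placeholder, since the Space reads it from the environment):

import os

from huggingface_hub import hf_hub_download

from generation_utils import load_model

os.environ.setdefault("MODEL_REPO_ID", "your-org/your-MOSS-TTSD-model")  # placeholder repo id

SPT_CONFIG_PATH = "XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml"
SPT_CHECKPOINT_PATH = hf_hub_download(
    repo_id="OpenMOSS-Team/MOSS_TTSD_tokenizer",
    filename="MOSS_TTSD_tokenizer",
    cache_dir="XY_Tokenizer/weights",
)

tokenizer, model, spt = load_model(
    os.environ["MODEL_REPO_ID"], SPT_CONFIG_PATH, SPT_CHECKPOINT_PATH
)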
generation_utils.py
CHANGED
|
@@ -1,86 +1,181 @@
|
|
| 1 |
import os
|
| 2 |
import re
|
| 3 |
|
|
|
|
| 4 |
import torch
|
| 5 |
import torchaudio
|
| 6 |
-
import numpy as np
|
| 7 |
-
|
| 8 |
-
from transformers import AutoTokenizer
|
| 9 |
-
from modeling_asteroid import AsteroidTTSInstruct
|
| 10 |
-
from XY_Tokenizer.xy_tokenizer.model import XY_Tokenizer
|
| 11 |
|
| 12 |
MAX_CHANNELS = 8
|
| 13 |
-
SILENCE_DURATION = 0.0 # Fixed silence duration: 0 seconds
|
| 14 |
|
| 15 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 17 |
-
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
-
spt = XY_Tokenizer.load_from_checkpoint(config_path=spt_config_path, ckpt_path=spt_checkpoint_path)
|
| 21 |
-
|
| 22 |
model.eval()
|
| 23 |
spt.eval()
|
| 24 |
return tokenizer, model, spt
|
| 25 |
|
| 26 |
|
| 27 |
def process_jsonl_item(item):
|
| 28 |
-
"""
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
text = item.get("text", "")
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
if isinstance(
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
else:
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
speaker2_audio = os.path.join(base_path, prompt_audio_speaker2) if base_path and prompt_audio_speaker2 else prompt_audio_speaker2
|
| 58 |
else:
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
prompt_audio = {
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
if
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
"
|
| 76 |
-
|
| 77 |
-
"prompt_audio": prompt_audio
|
| 78 |
-
}
|
| 79 |
|
| 80 |
|
| 81 |
def load_audio_data(prompt_audio, target_sample_rate=16000):
|
| 82 |
"""Load audio data and return processed audio tensor
|
| 83 |
-
|
| 84 |
Args:
|
| 85 |
prompt_audio: Can be in the following formats:
|
| 86 |
- String: audio file path
|
|
@@ -89,10 +184,14 @@ def load_audio_data(prompt_audio, target_sample_rate=16000):
|
|
| 89 |
"""
|
| 90 |
if prompt_audio is None:
|
| 91 |
return None
|
| 92 |
-
|
| 93 |
try:
|
| 94 |
# Check if prompt_audio is a dictionary (containing speaker1 and speaker2)
|
| 95 |
-
if
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
# Process audio from both speakers separately
|
| 97 |
wav1, sr1 = _load_single_audio(prompt_audio["speaker1"])
|
| 98 |
wav2, sr2 = _load_single_audio(prompt_audio["speaker2"])
|
|
@@ -104,14 +203,14 @@ def load_audio_data(prompt_audio, target_sample_rate=16000):
|
|
| 104 |
# Single audio
|
| 105 |
wav, sr = _load_single_audio(prompt_audio)
|
| 106 |
# Resample to 16k
|
| 107 |
-
if sr != target_sample_rate:
|
| 108 |
wav = torchaudio.functional.resample(wav, sr, target_sample_rate)
|
| 109 |
# Ensure mono channel
|
| 110 |
if wav.shape[0] > 1:
|
| 111 |
wav = wav.mean(dim=0, keepdim=True) # Convert multi-channel to mono
|
| 112 |
-
if len(wav.shape) == 1:
|
| 113 |
wav = wav.unsqueeze(0)
|
| 114 |
-
|
| 115 |
return wav
|
| 116 |
except Exception as e:
|
| 117 |
print(f"Error loading audio data: {e}")
|
|
@@ -120,10 +219,10 @@ def load_audio_data(prompt_audio, target_sample_rate=16000):
|
|
| 120 |
|
| 121 |
def _load_single_audio(audio_input):
|
| 122 |
"""Load single audio, supports file path or (wav, sr) tuple
|
| 123 |
-
|
| 124 |
Args:
|
| 125 |
audio_input: String (file path) or tuple (wav, sr)
|
| 126 |
-
|
| 127 |
Returns:
|
| 128 |
tuple: (wav, sr)
|
| 129 |
"""
|
|
@@ -150,8 +249,8 @@ def merge_speaker_audios(wav1, sr1, wav2, sr2, target_sample_rate=16000):
|
|
| 150 |
wav1 = wav1.mean(dim=0, keepdim=True) # Convert multi-channel to mono
|
| 151 |
if len(wav1.shape) == 1:
|
| 152 |
wav1 = wav1.unsqueeze(0)
|
| 153 |
-
|
| 154 |
-
# Process second audio
|
| 155 |
if sr2 != target_sample_rate:
|
| 156 |
wav2 = torchaudio.functional.resample(wav2, sr2, target_sample_rate)
|
| 157 |
# Ensure mono channel
|
|
@@ -159,7 +258,7 @@ def merge_speaker_audios(wav1, sr1, wav2, sr2, target_sample_rate=16000):
|
|
| 159 |
wav2 = wav2.mean(dim=0, keepdim=True) # Convert multi-channel to mono
|
| 160 |
if len(wav2.shape) == 1:
|
| 161 |
wav2 = wav2.unsqueeze(0)
|
| 162 |
-
|
| 163 |
# Concatenate audio
|
| 164 |
merged_wav = torch.cat([wav1, wav2], dim=1)
|
| 165 |
return merged_wav
|
|
@@ -168,34 +267,48 @@ def merge_speaker_audios(wav1, sr1, wav2, sr2, target_sample_rate=16000):
|
|
| 168 |
raise
|
| 169 |
|
| 170 |
|
| 171 |
-
def process_inputs(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
seq = f"<|begin_of_style|>{prompt}<|end_of_style|>\n<|begin_of_text|>{text}<|end_of_text|>\n<|begin_of_speech|>"
|
| 173 |
inputs1 = np.array(tokenizer.encode(seq))
|
| 174 |
input_ids = np.full((inputs1.shape[0], max_channels), pad_token)
|
| 175 |
input_ids[:, 0] = inputs1
|
| 176 |
-
|
| 177 |
if audio_data is not None:
|
| 178 |
try:
|
| 179 |
# audio_data should now be a processed audio tensor
|
| 180 |
wav = audio_data
|
| 181 |
-
|
| 182 |
# Add fixed 5-second silence at the end of audio (using 16k sample rate)
|
| 183 |
-
silence_samples = int(
|
| 184 |
silence = torch.zeros(wav.shape[0], silence_samples)
|
| 185 |
wav = torch.cat([wav, silence], dim=1)
|
| 186 |
-
|
| 187 |
with torch.no_grad():
|
| 188 |
# Use SPT encoding
|
| 189 |
encode_result = spt.encode([wav.squeeze().to(device)])
|
| 190 |
-
audio_token =
|
| 191 |
-
|
|
|
|
|
|
|
| 192 |
# similar to DAC encoding adjustment
|
| 193 |
-
audio_token[:, 0] =
|
|
|
|
|
|
|
| 194 |
input_ids = np.concatenate([input_ids, audio_token])
|
| 195 |
except Exception as e:
|
| 196 |
print(f"Error processing audio data: {e}")
|
| 197 |
raise
|
| 198 |
-
|
| 199 |
return input_ids
|
| 200 |
|
| 201 |
|
|
@@ -203,7 +316,9 @@ def shifting_inputs(input_ids, tokenizer, pad_token=1024, max_channels=8):
|
|
| 203 |
seq_len = input_ids.shape[0]
|
| 204 |
new_seq_len = seq_len + max_channels - 1
|
| 205 |
shifted_input_ids = np.full((new_seq_len, max_channels), pad_token, dtype=np.int64)
|
| 206 |
-
shifted_input_ids[:, 0] = np.full(
|
|
|
|
|
|
|
| 207 |
for i in range(max_channels):
|
| 208 |
shifted_input_ids[i : (seq_len + i), i] = input_ids[:, i]
|
| 209 |
return shifted_input_ids
|
|
@@ -213,7 +328,7 @@ def rpadding(input_ids, channels, tokenizer):
|
|
| 213 |
attention_masks = [np.ones(inputs.shape[0]) for inputs in input_ids]
|
| 214 |
max_length = max(ids.shape[0] for ids in input_ids)
|
| 215 |
padded_input_ids, padded_attns = [], []
|
| 216 |
-
|
| 217 |
for ids, attn in zip(input_ids, attention_masks):
|
| 218 |
pad_len = max_length - ids.shape[0]
|
| 219 |
input_pad = np.full((pad_len, channels), 1024)
|
|
@@ -245,26 +360,23 @@ def normalize_text(text: str) -> str:
|
|
| 245 |
Normalize multi-speaker script.
|
| 246 |
|
| 247 |
1. Don't preserve line breaks.
|
| 248 |
-
2.
|
| 249 |
-
3. Remove decorative symbols:
|
| 250 |
-
4. Internal punctuation
|
| 251 |
5. Multiple 。 keep only the last one, others → ,。
|
| 252 |
6. Replace consecutive "哈" (>=2) with "(笑)".
|
| 253 |
7. Auto-recognize [S1] / [S2] … tags; if missing, treat as whole segment.
|
|
|
|
| 254 |
"""
|
| 255 |
# Replace [1], [2] etc. format with [S1], [S2] etc. format
|
| 256 |
-
text = re.sub(r
|
| 257 |
|
| 258 |
# Remove decorative characters
|
| 259 |
-
remove_chars = "【】《》()『』「」""
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
# Remove brackets for non-speaker tags (keep content, only remove brackets themselves)
|
| 263 |
-
text = re.sub(r'\[(?!S\d+\])([^\]]*)\]', r'\1', text)
|
| 264 |
|
| 265 |
# Use positive lookahead to split text by speaker tags (tags themselves are still preserved)
|
| 266 |
-
segments = re.split(r
|
| 267 |
-
|
| 268 |
|
| 269 |
for seg in segments:
|
| 270 |
seg = seg.strip()
|
|
@@ -272,42 +384,73 @@ def normalize_text(text: str) -> str:
|
|
| 272 |
continue
|
| 273 |
|
| 274 |
# Extract tags
|
| 275 |
-
m = re.match(r
|
| 276 |
-
tag, content = m.groups() if m else (
|
| 277 |
|
| 278 |
# Remove irrelevant symbols
|
| 279 |
content = re.sub(f"[{re.escape(remove_chars)}]", "", content)
|
| 280 |
|
| 281 |
# Handle consecutive "哈" characters: replace 2 or more with "(笑)"
|
| 282 |
-
content = re.sub(r
|
|
|
|
|
|
|
|
|
|
| 283 |
|
| 284 |
# First handle multi-character punctuation marks
|
| 285 |
-
content = content.replace(
|
| 286 |
-
content = content.replace(
|
| 287 |
|
| 288 |
# Handle single-character internal punctuation marks
|
| 289 |
-
internal_punct_map = str.maketrans(
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
':': ',', ':': ',',
|
| 293 |
-
'、': ',',
|
| 294 |
-
'?': ',', '?': ','
|
| 295 |
-
})
|
| 296 |
content = content.translate(internal_punct_map)
|
| 297 |
content = content.strip()
|
| 298 |
|
| 299 |
# Keep only the final period
|
| 300 |
if len(content) > 1:
|
| 301 |
-
last_ch =
|
| 302 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 303 |
content = body + last_ch
|
| 304 |
|
| 305 |
-
|
| 306 |
|
| 307 |
-
|
|
|
|
| 308 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 309 |
|
| 310 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 311 |
"""Process a batch of data items and generate audio, return audio data and metadata"""
|
| 312 |
try:
|
| 313 |
# Prepare batch data
|
|
@@ -316,64 +459,74 @@ def process_batch(batch_items, tokenizer, model, spt, device, system_prompt, sta
|
|
| 316 |
prompts = [system_prompt] * batch_size
|
| 317 |
prompt_audios = []
|
| 318 |
actual_texts_data = [] # Store actual text data used
|
| 319 |
-
|
| 320 |
print(f"Processing {batch_size} samples starting from index {start_idx}...")
|
| 321 |
-
|
| 322 |
# Extract text and audio from each sample
|
| 323 |
for i, item in enumerate(batch_items):
|
| 324 |
# Use new processing function
|
| 325 |
processed_item = process_jsonl_item(item)
|
| 326 |
-
|
| 327 |
text = processed_item["text"]
|
| 328 |
prompt_text = processed_item["prompt_text"]
|
| 329 |
-
|
| 330 |
-
# Merge text
|
| 331 |
-
full_text = prompt_text + text
|
| 332 |
original_full_text = full_text # Save original text
|
| 333 |
-
|
| 334 |
# Apply text normalization based on parameter
|
| 335 |
if use_normalize:
|
| 336 |
full_text = normalize_text(full_text)
|
| 337 |
-
|
| 338 |
# Replace speaker tags
|
| 339 |
-
final_text = full_text.replace("[S1]", "<speaker1>").replace(
|
|
|
|
|
|
|
| 340 |
texts.append(final_text)
|
| 341 |
-
|
| 342 |
# Save actual text information used
|
| 343 |
-
actual_texts_data.append(
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 351 |
# Get reference audio
|
| 352 |
prompt_audios.append(processed_item["prompt_audio"])
|
| 353 |
-
|
| 354 |
# Process inputs
|
| 355 |
input_ids_list = []
|
| 356 |
-
for i, (text, prompt, audio_path) in enumerate(
|
|
|
|
|
|
|
| 357 |
# Load audio data here
|
| 358 |
audio_data = load_audio_data(audio_path) if audio_path else None
|
| 359 |
-
inputs = process_inputs(
|
|
|
|
|
|
|
| 360 |
inputs = shifting_inputs(inputs, tokenizer)
|
| 361 |
input_ids_list.append(inputs)
|
| 362 |
-
|
| 363 |
# Pad batch inputs
|
| 364 |
input_ids, attention_mask = rpadding(input_ids_list, MAX_CHANNELS, tokenizer)
|
| 365 |
-
|
| 366 |
# Batch generation
|
| 367 |
print(f"Starting batch audio generation...")
|
| 368 |
start = input_ids.shape[1] - MAX_CHANNELS + 1
|
| 369 |
-
|
| 370 |
# Move inputs to GPU
|
| 371 |
input_ids = input_ids.to(device)
|
| 372 |
attention_mask = attention_mask.to(device)
|
| 373 |
-
|
| 374 |
# Generate model outputs
|
| 375 |
outputs = model.generate(
|
| 376 |
-
input_ids=input_ids,
|
| 377 |
attention_mask=attention_mask,
|
| 378 |
)
|
| 379 |
print(f"Original outputs shape: {outputs.shape}")
|
|
@@ -385,20 +538,19 @@ def process_batch(batch_items, tokenizer, model, spt, device, system_prompt, sta
|
|
| 385 |
outputs = outputs[:, start:]
|
| 386 |
seq_len = outputs.shape[1] - MAX_CHANNELS + 1
|
| 387 |
speech_ids = torch.full((outputs.shape[0], seq_len, MAX_CHANNELS), 0).to(device)
|
| 388 |
-
|
| 389 |
-
|
| 390 |
# Adjust output format
|
| 391 |
for j in range(MAX_CHANNELS):
|
| 392 |
speech_ids[..., j] = outputs[:, j : seq_len + j, j]
|
| 393 |
-
if j == 0:
|
| 394 |
speech_ids[..., j] = speech_ids[..., j] - 151665
|
| 395 |
-
|
| 396 |
# Find valid positions for each sample
|
| 397 |
li = find_max_valid_positions(speech_ids)
|
| 398 |
-
|
| 399 |
# Store audio result data
|
| 400 |
audio_results = []
|
| 401 |
-
|
| 402 |
# Process batch sample results individually
|
| 403 |
for i in range(batch_size):
|
| 404 |
try:
|
|
@@ -408,39 +560,200 @@ def process_batch(batch_items, tokenizer, model, spt, device, system_prompt, sta
|
|
| 408 |
print(f"Sample {start_idx + i} has no valid speech tokens")
|
| 409 |
audio_results.append(None)
|
| 410 |
continue
|
| 411 |
-
|
| 412 |
this_speech_id = speech_ids[i, :end_idx]
|
| 413 |
-
print(
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
| 424 |
-
|
| 425 |
-
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 432 |
except Exception as e:
|
| 433 |
print(f"Error processing sample {start_idx + i}: {str(e)}, skipping...")
|
| 434 |
import traceback
|
|
|
|
| 435 |
traceback.print_exc()
|
| 436 |
audio_results.append(None)
|
| 437 |
-
|
| 438 |
# Clean up GPU memory
|
| 439 |
torch.cuda.empty_cache()
|
| 440 |
-
|
| 441 |
# Return text data and audio data
|
| 442 |
return actual_texts_data, audio_results
|
| 443 |
-
|
| 444 |
except Exception as e:
|
| 445 |
print(f"Error during batch processing: {str(e)}")
|
| 446 |
-
raise
|
|
|
|
| 1 |
import os
|
| 2 |
import re
|
| 3 |
|
| 4 |
+
import numpy as np
|
| 5 |
import torch
|
| 6 |
import torchaudio
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
MAX_CHANNELS = 8
|
|
|
|
| 9 |
|
| 10 |
+
def pad_or_truncate_to_seconds(
|
| 11 |
+
wav: torch.Tensor, target_seconds: float, sr: int
|
| 12 |
+
) -> torch.Tensor:
|
| 13 |
+
"""Pad or truncate a mono waveform to target length in seconds.
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
wav: (1, T) or (T,) tensor
|
| 17 |
+
target_seconds: target duration in seconds
|
| 18 |
+
sr: sample rate
|
| 19 |
+
Returns:
|
| 20 |
+
(1, T_target) tensor
|
| 21 |
+
"""
|
| 22 |
+
if wav.dim() == 2 and wav.shape[0] == 1:
|
| 23 |
+
wav_1d = wav.squeeze(0)
|
| 24 |
+
else:
|
| 25 |
+
wav_1d = wav.reshape(-1)
|
| 26 |
+
target_len = int(round(target_seconds * sr))
|
| 27 |
+
cur_len = wav_1d.shape[-1]
|
| 28 |
+
if cur_len == target_len:
|
| 29 |
+
out = wav_1d
|
| 30 |
+
elif cur_len > target_len:
|
| 31 |
+
out = wav_1d[:target_len]
|
| 32 |
+
else:
|
| 33 |
+
pad_len = target_len - cur_len
|
| 34 |
+
out = torch.cat(
|
| 35 |
+
[wav_1d, torch.zeros(pad_len, dtype=wav_1d.dtype, device=wav_1d.device)],
|
| 36 |
+
dim=-1,
|
| 37 |
+
)
|
| 38 |
+
return out.unsqueeze(0)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def crossfade_concat(
|
| 42 |
+
segments: list, sample_rate: int, crossfade_seconds: float = 0.1
|
| 43 |
+
) -> torch.Tensor:
|
| 44 |
+
"""Concatenate segments with linear crossfade.
|
| 45 |
+
|
| 46 |
+
Args:
|
| 47 |
+
segments: list of (1, T) tensors
|
| 48 |
+
sample_rate: sampling rate
|
| 49 |
+
crossfade_seconds: overlap time for crossfade
|
| 50 |
+
Returns:
|
| 51 |
+
(1, T_total) tensor
|
| 52 |
+
"""
|
| 53 |
+
if len(segments) == 0:
|
| 54 |
+
return torch.zeros(1, 0)
|
| 55 |
+
if len(segments) == 1:
|
| 56 |
+
return segments[0]
|
| 57 |
+
out = segments[0]
|
| 58 |
+
cf_len_target = int(round(crossfade_seconds * sample_rate))
|
| 59 |
+
for k in range(1, len(segments)):
|
| 60 |
+
nxt = segments[k]
|
| 61 |
+
if cf_len_target <= 0:
|
| 62 |
+
out = torch.cat([out, nxt], dim=-1)
|
| 63 |
+
continue
|
| 64 |
+
cf_len = min(cf_len_target, out.shape[-1], nxt.shape[-1])
|
| 65 |
+
if cf_len <= 0:
|
| 66 |
+
out = torch.cat([out, nxt], dim=-1)
|
| 67 |
+
continue
|
| 68 |
+
fade_out = torch.linspace(
|
| 69 |
+
1.0, 0.0, steps=cf_len, dtype=out.dtype, device=out.device
|
| 70 |
+
)
|
| 71 |
+
fade_in = torch.linspace(
|
| 72 |
+
0.0, 1.0, steps=cf_len, dtype=nxt.dtype, device=nxt.device
|
| 73 |
+
)
|
| 74 |
+
overlap = out[0, -cf_len:] * fade_out + nxt[0, :cf_len] * fade_in
|
| 75 |
+
out = torch.cat(
|
| 76 |
+
[out[:, :-cf_len], overlap.unsqueeze(0), nxt[:, cf_len:]], dim=-1
|
| 77 |
+
)
|
| 78 |
+
return out
|
| 79 |
+
|
| 80 |
+
def load_model(
|
| 81 |
+
model_path,
|
| 82 |
+
spt_config_path,
|
| 83 |
+
spt_checkpoint_path,
|
| 84 |
+
torch_dtype=torch.bfloat16,
|
| 85 |
+
attn_implementation="sdpa",
|
| 86 |
+
):
|
| 87 |
+
from transformers import AutoTokenizer
|
| 88 |
+
|
| 89 |
+
from modeling_asteroid import AsteroidTTSInstruct
|
| 90 |
+
from XY_Tokenizer.xy_tokenizer.model import XY_Tokenizer
|
| 91 |
+
|
| 92 |
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 93 |
+
model = AsteroidTTSInstruct.from_pretrained(
|
| 94 |
+
model_path, torch_dtype=torch_dtype, attn_implementation=attn_implementation
|
| 95 |
+
)
|
| 96 |
+
spt = XY_Tokenizer.load_from_checkpoint(
|
| 97 |
+
config_path=spt_config_path, ckpt_path=spt_checkpoint_path
|
| 98 |
+
)
|
| 99 |
|
|
|
|
|
|
|
| 100 |
model.eval()
|
| 101 |
spt.eval()
|
| 102 |
return tokenizer, model, spt
|
| 103 |
|
| 104 |
|
| 105 |
def process_jsonl_item(item):
|
| 106 |
+
"""Parse a JSONL item enforcing prompt requirement.
|
| 107 |
+
|
| 108 |
+
Only supports Format 1 (separate speaker refs) and Format 2 (shared ref),
|
| 109 |
+
consistent with the updated README. If `base_path` is missing/empty, any
|
| 110 |
+
string paths must be absolute. Text-only input is not supported and will raise.
|
| 111 |
+
"""
|
| 112 |
+
base_path = item.get("base_path", "") or ""
|
| 113 |
text = item.get("text", "")
|
| 114 |
+
|
| 115 |
+
def _resolve_path(p: str) -> str:
|
| 116 |
+
if not isinstance(p, str) or not p:
|
| 117 |
+
return p
|
| 118 |
+
if base_path:
|
| 119 |
+
return os.path.join(base_path, p)
|
| 120 |
+
# base_path missing: require absolute path
|
| 121 |
+
if not os.path.isabs(p):
|
| 122 |
+
raise ValueError(
|
| 123 |
+
"When base_path is omitted, audio paths must be absolute. Got: " + p
|
| 124 |
+
)
|
| 125 |
+
return p
|
| 126 |
+
|
| 127 |
+
# Try Format 2 first: shared audio reference
|
| 128 |
+
prompt_audio = None
|
| 129 |
+
prompt_text = ""
|
| 130 |
+
if "prompt_audio" in item:
|
| 131 |
+
prompt_audio_val = item.get("prompt_audio")
|
| 132 |
+
if not prompt_audio_val:
|
| 133 |
+
raise ValueError("Format 2 requires non-empty 'prompt_audio'.")
|
| 134 |
+
if isinstance(prompt_audio_val, str):
|
| 135 |
+
prompt_audio = _resolve_path(prompt_audio_val)
|
| 136 |
+
else:
|
| 137 |
+
# allow tuple form for backward-compatibility
|
| 138 |
+
prompt_audio = prompt_audio_val
|
| 139 |
+
prompt_text = item.get("prompt_text", "")
|
| 140 |
+
return {"text": text, "prompt_text": prompt_text, "prompt_audio": prompt_audio}
|
| 141 |
+
|
| 142 |
+
# Try Format 1: separate speaker references
|
| 143 |
+
s1 = item.get("prompt_audio_speaker1", "")
|
| 144 |
+
s2 = item.get("prompt_audio_speaker2", "")
|
| 145 |
+
has_s1 = (isinstance(s1, str) and s1) or isinstance(s1, tuple)
|
| 146 |
+
has_s2 = (isinstance(s2, str) and s2) or isinstance(s2, tuple)
|
| 147 |
+
|
| 148 |
+
if has_s1 and has_s2:
|
| 149 |
+
if isinstance(s1, str) and s1:
|
| 150 |
+
s1_resolved = _resolve_path(s1)
|
| 151 |
else:
|
| 152 |
+
s1_resolved = s1
|
| 153 |
+
if isinstance(s2, str) and s2:
|
| 154 |
+
s2_resolved = _resolve_path(s2)
|
|
|
|
| 155 |
else:
|
| 156 |
+
s2_resolved = s2
|
| 157 |
+
# Build merged prompt audio dict
|
| 158 |
+
prompt_audio = {"speaker1": s1_resolved, "speaker2": s2_resolved}
|
| 159 |
+
# Merge texts
|
| 160 |
+
pt1 = item.get("prompt_text_speaker1", "")
|
| 161 |
+
pt2 = item.get("prompt_text_speaker2", "")
|
| 162 |
+
merged = ""
|
| 163 |
+
if pt1:
|
| 164 |
+
merged += f"[S1]{pt1}"
|
| 165 |
+
if pt2:
|
| 166 |
+
merged += f"[S2]{pt2}"
|
| 167 |
+
prompt_text = merged.strip()
|
| 168 |
+
return {"text": text, "prompt_text": prompt_text, "prompt_audio": prompt_audio}
|
| 169 |
+
|
| 170 |
+
# Otherwise, no supported prompt found → reject (text-only unsupported)
|
| 171 |
+
raise ValueError(
|
| 172 |
+
"Input must include prompt (Format 1 or 2). Text-only is not supported."
|
| 173 |
+
)
|
|
|
|
|
|
|
| 174 |
|
| 175 |
|
| 176 |
def load_audio_data(prompt_audio, target_sample_rate=16000):
|
| 177 |
"""Load audio data and return processed audio tensor
|
| 178 |
+
|
| 179 |
Args:
|
| 180 |
prompt_audio: Can be in the following formats:
|
| 181 |
- String: audio file path
|
|
|
|
| 184 |
"""
|
| 185 |
if prompt_audio is None:
|
| 186 |
return None
|
| 187 |
+
|
| 188 |
try:
|
| 189 |
# Check if prompt_audio is a dictionary (containing speaker1 and speaker2)
|
| 190 |
+
if (
|
| 191 |
+
isinstance(prompt_audio, dict)
|
| 192 |
+
and "speaker1" in prompt_audio
|
| 193 |
+
and "speaker2" in prompt_audio
|
| 194 |
+
):
|
| 195 |
# Process audio from both speakers separately
|
| 196 |
wav1, sr1 = _load_single_audio(prompt_audio["speaker1"])
|
| 197 |
wav2, sr2 = _load_single_audio(prompt_audio["speaker2"])
|
|
|
|
| 203 |
# Single audio
|
| 204 |
wav, sr = _load_single_audio(prompt_audio)
|
| 205 |
# Resample to 16k
|
| 206 |
+
if sr != target_sample_rate:
|
| 207 |
wav = torchaudio.functional.resample(wav, sr, target_sample_rate)
|
| 208 |
# Ensure mono channel
|
| 209 |
if wav.shape[0] > 1:
|
| 210 |
wav = wav.mean(dim=0, keepdim=True) # Convert multi-channel to mono
|
| 211 |
+
if len(wav.shape) == 1:
|
| 212 |
wav = wav.unsqueeze(0)
|
| 213 |
+
|
| 214 |
return wav
|
| 215 |
except Exception as e:
|
| 216 |
print(f"Error loading audio data: {e}")
|
|
|
|
| 219 |
|
| 220 |
def _load_single_audio(audio_input):
|
| 221 |
"""Load single audio, supports file path or (wav, sr) tuple
|
| 222 |
+
|
| 223 |
Args:
|
| 224 |
audio_input: String (file path) or tuple (wav, sr)
|
| 225 |
+
|
| 226 |
Returns:
|
| 227 |
tuple: (wav, sr)
|
| 228 |
"""
|
|
|
|
| 249 |
wav1 = wav1.mean(dim=0, keepdim=True) # Convert multi-channel to mono
|
| 250 |
if len(wav1.shape) == 1:
|
| 251 |
wav1 = wav1.unsqueeze(0)
|
| 252 |
+
|
| 253 |
+
# Process second audio
|
| 254 |
if sr2 != target_sample_rate:
|
| 255 |
wav2 = torchaudio.functional.resample(wav2, sr2, target_sample_rate)
|
| 256 |
# Ensure mono channel
|
|
|
|
| 258 |
wav2 = wav2.mean(dim=0, keepdim=True) # Convert multi-channel to mono
|
| 259 |
if len(wav2.shape) == 1:
|
| 260 |
wav2 = wav2.unsqueeze(0)
|
| 261 |
+
|
| 262 |
# Concatenate audio
|
| 263 |
merged_wav = torch.cat([wav1, wav2], dim=1)
|
| 264 |
return merged_wav
|
|
|
|
| 267 |
raise
|
| 268 |
|
| 269 |
|
def process_inputs(
    tokenizer,
    spt,
    prompt,
    text,
    device,
    silence_duration,
    audio_data=None,
    max_channels=8,
    pad_token=1024,
):
    seq = f"<|begin_of_style|>{prompt}<|end_of_style|>\n<|begin_of_text|>{text}<|end_of_text|>\n<|begin_of_speech|>"
    inputs1 = np.array(tokenizer.encode(seq))
    input_ids = np.full((inputs1.shape[0], max_channels), pad_token)
    input_ids[:, 0] = inputs1

    if audio_data is not None:
        try:
            # audio_data should now be a processed audio tensor
            wav = audio_data

            # Append silence_duration seconds of silence at the end of the audio (16 kHz sample rate)
            silence_samples = int(silence_duration * 16000)
            silence = torch.zeros(wav.shape[0], silence_samples)
            wav = torch.cat([wav, silence], dim=1)

            with torch.no_grad():
                # Use SPT encoding
                encode_result = spt.encode([wav.squeeze().to(device)])
                audio_token = (
                    encode_result["codes_list"][0].permute(1, 0).cpu().numpy()
                )  # Adjust dimension order

                # Similar to the DAC encoding adjustment: offset channel-0 codes
                # (the same 151665 is subtracted again after generation)
                audio_token[:, 0] = audio_token[:, 0] + 151665
                input_ids = np.concatenate([input_ids, audio_token])
        except Exception as e:
            print(f"Error processing audio data: {e}")
            raise

    return input_ids

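The matrix returned by process_inputs is laid out as (sequence_length, max_channels): the encoded style/text prompt occupies channel 0 while the remaining channels hold the pad code 1024, and, when reference audio is supplied, the 8-channel codec codes are appended row-wise with channel 0 shifted by 151665, presumably to keep them clear of the text token ids. A rough usage sketch (argument values are hypothetical, and tokenizer/spt/system_prompt are assumed to be already loaded):

# Sketch: with no reference audio, only the channel-0 text tokens are filled in
ids = process_inputs(
    tokenizer, spt, system_prompt, "<speaker1>Hello there.<speaker2>Hi!",
    device="cpu", silence_duration=0, audio_data=None,
)
print(ids.shape)                   # (num_text_tokens, 8)
print((ids[:, 1:] == 1024).all())  # True: non-text channels are still padding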
# ... (unchanged lines collapsed in the diff) ...
    seq_len = input_ids.shape[0]
    new_seq_len = seq_len + max_channels - 1
    shifted_input_ids = np.full((new_seq_len, max_channels), pad_token, dtype=np.int64)
    shifted_input_ids[:, 0] = np.full(
        new_seq_len, tokenizer.pad_token_id, dtype=np.int64
    )
    for i in range(max_channels):
        shifted_input_ids[i : (seq_len + i), i] = input_ids[:, i]
    return shifted_input_ids
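The shift above is a delay pattern in the style of MusicGen's codebook interleaving: channel j of a given audio frame is emitted j steps after channel 0, with the freed positions filled by pad tokens. A self-contained demo of the same shift (the text_pad value here is a placeholder, not the real tokenizer.pad_token_id):

import numpy as np


def shift_demo(input_ids, max_channels=4, pad_token=1024, text_pad=0):
    # Same delay pattern as shifting_inputs: channel j starts j rows later
    seq_len = input_ids.shape[0]
    new_seq_len = seq_len + max_channels - 1
    out = np.full((new_seq_len, max_channels), pad_token, dtype=np.int64)
    out[:, 0] = text_pad  # channel 0 is pre-filled with the text pad id
    for j in range(max_channels):
        out[j : seq_len + j, j] = input_ids[:, j]
    return out


demo = shift_demo(np.arange(12).reshape(3, 4))
print(demo.shape)  # (6, 4): every channel delayed by its index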
# ... (unchanged lines collapsed in the diff) ...
    attention_masks = [np.ones(inputs.shape[0]) for inputs in input_ids]
    max_length = max(ids.shape[0] for ids in input_ids)
    padded_input_ids, padded_attns = [], []

    for ids, attn in zip(input_ids, attention_masks):
        pad_len = max_length - ids.shape[0]
        input_pad = np.full((pad_len, channels), 1024)
        # ... (rest of the function and the following lines collapsed in the diff) ...
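The rest of rpadding is collapsed in the diff. From the visible setup it evidently pads every sample to max_length with the 1024 pad code and builds matching attention masks; the sketch below is one way the function could be completed, and the pad position, dtypes, and return types are assumptions rather than the PR's actual code.

import numpy as np
import torch


def rpadding_sketch(input_ids, channels, tokenizer, pad_token=1024):
    attention_masks = [np.ones(ids.shape[0]) for ids in input_ids]
    max_length = max(ids.shape[0] for ids in input_ids)
    padded_input_ids, padded_attns = [], []
    for ids, attn in zip(input_ids, attention_masks):
        pad_len = max_length - ids.shape[0]
        input_pad = np.full((pad_len, channels), pad_token)
        input_pad[:, 0] = tokenizer.pad_token_id  # assumption: channel 0 pads with the text pad id
        padded_input_ids.append(np.concatenate([ids, input_pad], axis=0))       # assumption: right padding
        padded_attns.append(np.concatenate([attn, np.zeros(pad_len)], axis=0))
    return (
        torch.tensor(np.stack(padded_input_ids), dtype=torch.long),
        torch.tensor(np.stack(padded_attns), dtype=torch.long),
    )

The diff then jumps ahead to the docstring and body of the script normalizer (normalize_text), reproduced next.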
    Normalize multi-speaker script.

    1. Don't preserve line breaks.
    2. Preserve bracketed segments like [] () <> even when they are not speaker tags.
    3. Remove decorative symbols: 【】《》()『』「」~~-_.
    4. Internal punctuation ;:、 → ,; keep ? and !.
    5. When there are multiple 。, keep only the last one; the others become ,.
    6. Replace consecutive "哈" (>= 2) with "[笑]"; English "haha" becomes "[laugh]".
    7. Auto-recognize [S1] / [S2] … tags; if missing, treat as whole segment.
    8. Merge adjacent identical speaker tags.
    """
    # Replace [1], [2] etc. format with [S1], [S2] etc. format
    text = re.sub(r"\[(\d+)\]", r"[S\1]", text)

    # Remove decorative characters
    remove_chars = "【】《》()『』「」" '"-_“”~~‘’'

    # Use positive lookahead to split text by speaker tags (tags themselves are still preserved)
    segments = re.split(r"(?=\[S\d+\])", text.replace("\n", " "))
    processed_parts = []

    for seg in segments:
        seg = seg.strip()
        # ... (unchanged line collapsed in the diff) ...
            continue

        # Extract tags
        m = re.match(r"^(\[S\d+\])\s*(.*)", seg)
        tag, content = m.groups() if m else ("", seg)

        # Remove irrelevant symbols
        content = re.sub(f"[{re.escape(remove_chars)}]", "", content)

        # Handle consecutive "哈" characters: replace 2 or more with "[笑]"
        content = re.sub(r"哈{2,}", "[笑]", content)

        # Handle English laughter (e.g., "haha", "ha ha")
        content = re.sub(r"\b(ha(\s*ha)+)\b", "[laugh]", content, flags=re.IGNORECASE)

        # First handle multi-character punctuation marks
        content = content.replace("——", ",")
        content = content.replace("……", ",")

        # Handle single-character internal punctuation marks
        internal_punct_map = str.maketrans(
            {";": ",", ";": ",", ":": ",", ":": ",", "、": ","}
        )
        content = content.translate(internal_punct_map)
        content = content.strip()

        # Keep only the final period
        if len(content) > 1:
            last_ch = (
                "。"
                if content[-1] == ","
                else ("." if content[-1] == "," else content[-1])
            )
            body = content[:-1].replace("。", ",")
            content = body + last_ch

        processed_parts.append({"tag": tag, "content": content})

    if not processed_parts:
        return ""

    # Merge consecutive same speakers
    merged_lines = []
    current_tag = processed_parts[0]["tag"]
    current_content = [processed_parts[0]["content"]]

    for part in processed_parts[1:]:
        if part["tag"] == current_tag and current_tag:
            current_content.append(part["content"])
        else:
            merged_lines.append(f"{current_tag}{''.join(current_content)}".strip())
            current_tag = part["tag"]
            current_content = [part["content"]]

    merged_lines.append(f"{current_tag}{''.join(current_content)}".strip())

    return "".join(merged_lines).replace("‘", "'").replace("’", "'")

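As a quick illustration of normalize_text (the function process_batch calls below when use_normalize is set), consider a hypothetical script; the comments describe the intended effects rather than an exact expected string:

raw = "[1]《欢迎》哈哈哈\n[1]今天怎么样;还好吗?\n[2]ha ha, not bad."
cleaned = normalize_text(raw)
# Intended effects (the exact output depends on the mappings above):
#   [1] / [2]  -> [S1] / [S2]
#   《欢迎》    -> 欢迎      (decorative brackets removed)
#   哈哈哈      -> [笑]
#   "ha ha"    -> [laugh]
#   ;          -> ,         (internal punctuation)
#   the two adjacent [S1] turns are merged into a single [S1] segment
print(cleaned)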

def process_batch(
    batch_items,
    tokenizer,
    model,
    spt,
    device,
    system_prompt,
    start_idx,
    use_normalize=False,
    silence_duration=0,
):
    """Process a batch of data items and generate audio, return audio data and metadata"""
    try:
        # Prepare batch data
        # ... (unchanged lines collapsed in the diff) ...
        prompts = [system_prompt] * batch_size
        prompt_audios = []
        actual_texts_data = []  # Store actual text data used

        print(f"Processing {batch_size} samples starting from index {start_idx}...")

        # Extract text and audio from each sample
        for i, item in enumerate(batch_items):
            # Use new processing function
            processed_item = process_jsonl_item(item)

            text = processed_item["text"]
            prompt_text = processed_item["prompt_text"]

            # Merge text, if prompt_text is empty, full_text is just text
            full_text = prompt_text + text if prompt_text else text
            original_full_text = full_text  # Save original text

            # Apply text normalization based on parameter
            if use_normalize:
                full_text = normalize_text(full_text)

            # Replace speaker tags
            final_text = full_text.replace("[S1]", "<speaker1>").replace(
                "[S2]", "<speaker2>"
            )
            texts.append(final_text)

            # Save actual text information used
            actual_texts_data.append(
                {
                    "index": start_idx + i,
                    "original_text": original_full_text,
                    "normalized_text": (
                        normalize_text(original_full_text) if use_normalize else None
                    ),
                    "final_text": final_text,
                    "use_normalize": use_normalize,
                }
            )

            # Get reference audio
            prompt_audios.append(processed_item["prompt_audio"])

        # Process inputs
        input_ids_list = []
        for i, (text, prompt, audio_path) in enumerate(
            zip(texts, prompts, prompt_audios)
        ):
            # Load audio data here
            audio_data = load_audio_data(audio_path) if audio_path else None
            inputs = process_inputs(
                tokenizer, spt, prompt, text, device, silence_duration, audio_data
            )
            inputs = shifting_inputs(inputs, tokenizer)
            input_ids_list.append(inputs)

        # Pad batch inputs
        input_ids, attention_mask = rpadding(input_ids_list, MAX_CHANNELS, tokenizer)

        # Batch generation
        print(f"Starting batch audio generation...")
        start = input_ids.shape[1] - MAX_CHANNELS + 1

        # Move inputs to GPU
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        # Generate model outputs
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        print(f"Original outputs shape: {outputs.shape}")
        # ... (unchanged lines collapsed in the diff) ...
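        # Note (added for clarity, not part of the original diff): `outputs` still
        # contains the prompt region plus the per-channel delay introduced by
        # shifting_inputs. The lines below trim the prompt (keeping the short overlap
        # the delay pattern needs), re-align the channels, and subtract the 151665
        # offset that was added to channel 0.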
        outputs = outputs[:, start:]
        seq_len = outputs.shape[1] - MAX_CHANNELS + 1
        speech_ids = torch.full((outputs.shape[0], seq_len, MAX_CHANNELS), 0).to(device)

        # Adjust output format
        for j in range(MAX_CHANNELS):
            speech_ids[..., j] = outputs[:, j : seq_len + j, j]
            if j == 0:
                speech_ids[..., j] = speech_ids[..., j] - 151665

        # Find valid positions for each sample
        li = find_max_valid_positions(speech_ids)

        # Store audio result data
        audio_results = []

        # Process batch sample results individually
        for i in range(batch_size):
            try:
                # ... (unchanged lines collapsed in the diff) ...
                    print(f"Sample {start_idx + i} has no valid speech tokens")
                    audio_results.append(None)
                    continue

                this_speech_id = speech_ids[i, :end_idx]
                print(
                    f"Speech token shape for sample {start_idx + i}: {this_speech_id.shape}"
                )

                # Prompt-Augmented Decode (rvq8-style); fall back to original decode if no prompt
                prompt_audio = prompt_audios[i]
                if prompt_audio is None:
                    # Fallback to original decode
                    with torch.no_grad():
                        codes_list = [this_speech_id.permute(1, 0)]
                        decode_result = spt.decode(codes_list, overlap_seconds=10)
                        audio_out = decode_result["syn_wav_list"][0].cpu().detach()
                    if audio_out.ndim == 1:
                        audio_out = audio_out.unsqueeze(0)
                    audio_results.append(
                        {
                            "audio_data": audio_out,
                            "sample_rate": spt.output_sample_rate,
                            "index": start_idx + i,
                        }
                    )
                    print(f"Audio generation completed (orig): sample {start_idx + i}")
                else:
                    # 1) Load prompt at SPT input sr and force to 20s
                    ref_sr_in = (
                        getattr(spt, "input_sample_rate", None)
                        or getattr(spt, "sampling_rate", None)
                        or 24000
                    )
                    ref_wav = load_audio_data(
                        prompt_audio, target_sample_rate=ref_sr_in
                    )
                    if ref_wav is None:
                        # If ref missing, use original decode
                        with torch.no_grad():
                            codes_list = [this_speech_id.permute(1, 0)]
                            decode_result = spt.decode(codes_list, overlap_seconds=10)
                            audio_out = decode_result["syn_wav_list"][0].cpu().detach()
                        if audio_out.ndim == 1:
                            audio_out = audio_out.unsqueeze(0)
                        audio_results.append(
                            {
                                "audio_data": audio_out,
                                "sample_rate": spt.output_sample_rate,
                                "index": start_idx + i,
                            }
                        )
                        print(
                            f"Audio generation completed (orig no-ref): sample {start_idx + i}"
                        )
                    else:
                        # Encode 20s reference to tokens
                        ref_wav_20s = pad_or_truncate_to_seconds(
                            ref_wav, 20.0, ref_sr_in
                        ).to(device)
                        with torch.no_grad():
                            enc = spt.encode([ref_wav_20s.squeeze(0)])
                            ref_codes = (
                                enc["codes_list"][0].to(device).long()
                            )  # (nq, T_ref)

                        # Prepare token-to-sample mapping and windowing params
                        out_sr = (
                            getattr(spt, "output_sample_rate", None)
                            or getattr(spt, "sample_rate", None)
                            or 24000
                        )
                        tokens_per_second = float(ref_sr_in) / float(
                            spt.encoder_downsample_rate
                        )
                        tokens_per_chunk = int(round(10.0 * tokens_per_second))
                        stride_tokens = 85
                        keep_tokens = 85
                        left_ctx_tokens = 20
                        total_tokens = this_speech_id.shape[0]
                        samples_per_token = int(round(out_sr / tokens_per_second))
                        crossfade_seconds = 0.1
                        crossfade_samples = int(round(crossfade_seconds * out_sr))

                        kept_segments = []
                        chunk_idx = 0
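                        # Note (added for clarity, not part of the original diff): with
                        # a 16 kHz tokenizer input and a 1280x encoder downsample,
                        # tokens_per_second is 12.5, so a 10 s window is 125 tokens and
                        # stride_tokens = 85 advances roughly 6.8 s per chunk, leaving
                        # overlap for the left context and the crossfade.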
                        while True:
                            st_tok = chunk_idx * stride_tokens
                            if st_tok >= total_tokens:
                                break
                            ed_tok = min(st_tok + tokens_per_chunk, total_tokens)
                            gen_chunk = this_speech_id[st_tok:ed_tok]  # (len, C)
                            if gen_chunk.shape[0] == 0:
                                break

                            # Concatenate reference tokens with current window tokens
                            combined_codes = torch.cat(
                                [ref_codes, gen_chunk.permute(1, 0).long()], dim=1
                            ).to(device)  # (nq, T_ref + T_chunk)
                            codes_lengths = torch.tensor(
                                [combined_codes.shape[-1]],
                                dtype=torch.long,
                                device=device,
                            )
                            combined_codes_batched = combined_codes.unsqueeze(1)  # (nq, 1, T)

                            with torch.no_grad():
                                detok = spt.inference_detokenize(
                                    combined_codes_batched, codes_lengths
                                )
                            y = detok["y"][0, 0]  # (T_samples)

                            # Remove 20s reference portion (in samples)
                            ref_samples = int(round(20.0 * out_sr))
                            if y.shape[-1] <= ref_samples:
                                chunk_idx += 1
                                continue
                            chunk_y = y[ref_samples:]

                            # Determine kept region within current window
                            window_len = gen_chunk.shape[0]
                            remains = total_tokens - st_tok
                            is_first = chunk_idx == 0
                            is_last = ed_tok >= total_tokens

                            if is_first:
                                keep_start_tok = 0
                                keep_end_tok = min(
                                    keep_tokens + left_ctx_tokens, window_len
                                )
                            elif is_last and remains < 105:
                                keep_start_tok = (
                                    0 if is_first else min(left_ctx_tokens, window_len)
                                )
                                keep_end_tok = window_len
                            else:
                                keep_start_tok = min(left_ctx_tokens, window_len)
                                keep_end_tok = min(
                                    left_ctx_tokens + keep_tokens, window_len
                                )

                            keep_start_smps = keep_start_tok * samples_per_token
                            keep_end_smps = keep_end_tok * samples_per_token
                            left_margin = 0
                            right_margin = crossfade_samples if not is_last else 0
                            seg_start = max(0, keep_start_smps - left_margin)
                            seg_end = min(
                                chunk_y.shape[-1], keep_end_smps + right_margin
                            )
                            if seg_end > seg_start:
                                kept_segments.append(
                                    chunk_y[seg_start:seg_end].detach().cpu().unsqueeze(0)
                                )

                            chunk_idx += 1

                        # Concatenate with crossfade; if empty, return tiny silence
                        if len(kept_segments) == 0:
                            audio_out = torch.zeros(1, int(0.01 * out_sr))
                        else:
                            audio_out = crossfade_concat(
                                kept_segments,
                                out_sr,
                                crossfade_seconds=crossfade_seconds,
                            )

                        audio_results.append(
                            {
                                "audio_data": audio_out,
                                "sample_rate": out_sr,
                                "index": start_idx + i,
                            }
                        )
                        print(
                            f"Audio generation completed (prompt-aug): sample {start_idx + i}"
                        )

            except Exception as e:
                print(f"Error processing sample {start_idx + i}: {str(e)}, skipping...")
                import traceback

                traceback.print_exc()
                audio_results.append(None)

        # Clean up GPU memory
        torch.cuda.empty_cache()

        # Return text data and audio data
        return actual_texts_data, audio_results

    except Exception as e:
        print(f"Error during batch processing: {str(e)}")
        raise
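process_batch leans on two helpers that sit outside the visible part of this diff: pad_or_truncate_to_seconds and crossfade_concat. Their contracts are clear from the call sites above; the sketches below are one plausible shape for them (zero padding, a linear crossfade, (1, T) tensors) and not necessarily the implementations shipped in this PR.

import torch


def pad_or_truncate_to_seconds(wav, seconds, sample_rate):
    """Force a (1, T) waveform to exactly `seconds` by zero-padding or truncating."""
    target = int(round(seconds * sample_rate))
    if wav.shape[-1] >= target:
        return wav[..., :target]
    pad = torch.zeros(*wav.shape[:-1], target - wav.shape[-1], dtype=wav.dtype, device=wav.device)
    return torch.cat([wav, pad], dim=-1)


def crossfade_concat(segments, sample_rate, crossfade_seconds=0.1):
    """Concatenate (1, T_i) segments, blending each junction with a linear crossfade."""
    n = int(round(crossfade_seconds * sample_rate))
    out = segments[0]
    for seg in segments[1:]:
        k = min(n, out.shape[-1], seg.shape[-1])
        if k <= 0:
            out = torch.cat([out, seg], dim=-1)
            continue
        fade = torch.linspace(0.0, 1.0, k, dtype=out.dtype, device=out.device)
        blended = out[..., -k:] * (1.0 - fade) + seg[..., :k] * fade
        out = torch.cat([out[..., :-k], blended, seg[..., k:]], dim=-1)
    return out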