Cqy2019 rulerman committed
Commit 61bc2f6 · verified · 1 Parent(s): 82b777c

mossttsd-space (#2)


- update (dbd498f5b6a8eb7aff9ea559070143b6f55c6315)
- update2 (8008b4069c3a9980a0137258cf9aab4f866c1d98)


Co-authored-by: zyq <rulerman@users.noreply.huggingface.co>

XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml ADDED
@@ -0,0 +1,114 @@
+ generator_params:
+   input_sample_rate: 16000
+   output_sample_rate: 32000
+   encoder_downsample_rate: 1280
+   decoder_upsample_rate: 2560
+
+   feature_extractor_kwargs:
+     chunk_length: 30
+     feature_size: 80
+     hop_length: 160
+     n_fft: 400
+     n_samples: 480000
+     nb_max_frames: 3000
+     padding_side: right
+     padding_value: 0.0
+     return_attention_mask: false
+     sampling_rate: 16000
+
+   # Codec / model architecture (inference required)
+   semantic_encoder_kwargs: # 100hz -> 50hz
+     num_mel_bins: 80
+     sampling_rate: 16000
+     hop_length: 160
+     stride_size: 2
+     kernel_size: 3
+     d_model: 768
+     scale_embedding: false
+     max_audio_seconds: 30
+     encoder_layers: 12
+     encoder_attention_heads: 12
+     encoder_ffn_dim: 3072
+     activation_function: "gelu"
+
+   semantic_encoder_adapter_kwargs: # 50hz
+     input_dim: 768
+     output_dim: 768
+     d_model: 768
+     max_source_positions: 1500
+     encoder_layers: 4
+     encoder_attention_heads: 12
+     encoder_ffn_dim: 3072
+
+   acoustic_encoder_kwargs: # 100hz -> 50hz
+     num_mel_bins: 80
+     sampling_rate: 16000
+     hop_length: 160
+     stride_size: 2
+     kernel_size: 3
+     d_model: 768
+     scale_embedding: false
+     max_audio_seconds: 30
+     encoder_layers: 12
+     encoder_attention_heads: 12
+     encoder_ffn_dim: 3072
+     activation_function: "gelu"
+
+   pre_rvq_adapter_kwargs: # 50hz
+     input_dim: 1536
+     output_dim: 768
+     d_model: 768
+     max_source_positions: 1500
+     encoder_layers: 4
+     encoder_attention_heads: 12
+     encoder_ffn_dim: 3072
+
+   downsample_kwargs: # 50hz -> 12.5hz
+     d_model: 768
+     avg_pooler: 4
+
+   quantizer_kwargs: # 12.5hz
+     input_dim: 3072
+     rvq_dim: 512
+     output_dim: 3072
+     num_quantizers: 8
+     codebook_size: 1024
+     codebook_dim: 512
+     quantizer_dropout: 0.0
+     commitment: 1
+
+   post_rvq_adapter_kwargs: # 12.5hz
+     input_dim: 3072
+     output_dim: 3072
+     d_model: 768
+     max_source_positions: 375
+     encoder_layers: 4
+     encoder_attention_heads: 12
+     encoder_ffn_dim: 3072
+
+   upsample_kwargs: # 12.5hz -> 50hz
+     d_model: 768
+     stride: 4
+
+   acoustic_decoder_kwargs: # 50hz -> 100hz
+     num_mel_bins: 80
+     sampling_rate: 16000
+     hop_length: 160
+     stride_size: 2
+     kernel_size: 3
+     d_model: 768
+     scale_embedding: false
+     max_audio_seconds: 30
+     decoder_layers: 12
+     decoder_attention_heads: 12
+     decoder_ffn_dim: 3072
+     activation_function: "gelu"
+
+   vocos_kwargs: # 100hz -> 32khz
+     input_channels: 80
+     dim: 512
+     intermediate_dim: 4096
+     num_layers: 30
+     n_fft: 1280
+     hop_size: 320
+     padding: "same"
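
For reference, a minimal sketch (not part of this commit) of how the new config file is consumed. It follows the load_from_checkpoint classmethod shown in the model.py diff below; the local checkpoint path is a placeholder.

import yaml

from XY_Tokenizer.xy_tokenizer.model import XY_Tokenizer

CONFIG_PATH = "XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml"
CKPT_PATH = "XY_Tokenizer/weights/MOSS_TTSD_tokenizer"  # placeholder checkpoint path

# Inspect the generator parameters defined in the YAML above
with open(CONFIG_PATH, "r") as f:
    cfg = yaml.safe_load(f)
print(cfg["generator_params"]["input_sample_rate"])   # 16000
print(cfg["generator_params"]["output_sample_rate"])  # 32000

# Build the codec from config + checkpoint (classmethod defined in model.py below)
spt = XY_Tokenizer.load_from_checkpoint(config_path=CONFIG_PATH, ckpt_path=CKPT_PATH)
spt.eval()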
XY_Tokenizer/xy_tokenizer/model.py CHANGED
@@ -1,146 +1,198 @@
1
  # -*- coding: utf-8 -*-
2
- import yaml
3
  import logging
 
4
  import torch
5
  import torch.nn as nn
6
  import torch.nn.functional as F
7
-
8
 
9
  from .nn.feature_extractor import MelFeatureExtractor
10
- from .nn.modules import OmniAudioEncoder, OmniAudioDecoder, ResidualDownConv, UpConv, Transformer, Vocos
11
  from .nn.quantizer import ResidualVQ
12
 
 
13
  class XY_Tokenizer(nn.Module):
14
  def __init__(self, generator_params):
15
  super().__init__()
16
  # Basic parameters
17
- self.input_sample_rate = generator_params['input_sample_rate']
18
- self.output_sample_rate = generator_params['output_sample_rate']
19
-
20
- self.encoder_downsample_rate = 1280
21
- self.decoder_upsample_rate = 1920
22
- self.code_dim = generator_params['quantizer_kwargs']['input_dim']
23
-
24
  ## Codec part
25
 
26
  ## Semantic channel
27
- self.semantic_encoder = OmniAudioEncoder(**generator_params['semantic_encoder_kwargs'])
28
-
29
- self.semantic_encoder_adapter = Transformer(**generator_params['semantic_encoder_adapter_kwargs'])
30
-
 
 
 
 
31
  ## Acoustic channel
32
- self.acoustic_encoder = OmniAudioEncoder(**generator_params['acoustic_encoder_kwargs'])
33
-
 
 
34
  ## Semantic & acoustic shared parameters
35
- self.pre_rvq_adapter = Transformer(**generator_params['pre_rvq_adapter_kwargs'])
36
-
37
- self.downsample = ResidualDownConv(**generator_params['downsample_kwargs'])
38
-
39
- self.quantizer = ResidualVQ(**generator_params['quantizer_kwargs'])
40
- self.nq = generator_params['quantizer_kwargs']['num_quantizers']
41
-
42
- self.post_rvq_adapter = Transformer(**generator_params['post_rvq_adapter_kwargs'])
43
-
 
 
44
  ## Acoustic channel
45
- self.upsample = UpConv(**generator_params['upsample_kwargs'])
46
 
47
- self.acoustic_decoder = OmniAudioDecoder(**generator_params['acoustic_decoder_kwargs'])
 
 
48
 
49
- self.enhanced_vocos = Vocos(**generator_params['vocos_kwargs'])
50
 
51
  ## Feature extractor
52
- self.feature_extractor = MelFeatureExtractor(**generator_params['feature_extractor_kwargs'])
 
53
 
54
  @torch.inference_mode()
55
  def inference_tokenize(self, x, input_lengths):
56
  """
57
- Input:
58
- x: Waveform tensor # (B, 1, T), T <= 30s * sample_rate
59
- input_lengths: Valid length for each sample # (B,)
60
- Output:
61
- dict: Contains the following key-value pairs
62
- "zq": Quantized embeddings # (B, D, T)
63
- "codes": Quantization codes # (nq, B, T)
64
- "codes_lengths": Quantization code lengths # (B,)
65
  """
66
- list_x = [xi[:, :x_len].reshape(-1).cpu().numpy() for xi, x_len in zip(x, input_lengths)]
 
 
 
67
  features = self.feature_extractor(
68
  list_x,
69
  sampling_rate=self.input_sample_rate,
70
  return_tensors="pt",
71
- return_attention_mask=True
72
  )
73
- input_mel = features['input_features'].to(x.device).to(x.dtype) # (B, D, 3000)
74
- audio_attention_mask = features['attention_mask'].to(x.device) # (B, 3000)
75
-
76
  # Get batch size and sequence length of the input
77
- mel_output_length = torch.sum(audio_attention_mask, dim=-1).long() # (B,)
78
-
79
  # Semantic channel
80
- semantic_encoder_output, semantic_encoder_output_length = self.semantic_encoder(input_mel, mel_output_length) # (B, D, T), 100hz -> 50hz
81
-
82
- semantic_encoder_adapter_output, semantic_encoder_adapter_output_length = self.semantic_encoder_adapter(semantic_encoder_output, semantic_encoder_output_length) # (B, D, T), 50hz
83
-
 
 
 
 
 
 
84
  # Acoustic channel
85
- acoustic_encoder_output, acoustic_encoder_output_length = self.acoustic_encoder(input_mel, mel_output_length) # (B, D, T), 100hz -> 50hz
86
-
 
 
87
  # Semantic & acoustic mixing
88
- concated_semantic_acoustic_channel = torch.concat([semantic_encoder_adapter_output, acoustic_encoder_output], dim=1) # (B, D, T)
 
 
89
  concated_semantic_acoustic_channel_length = acoustic_encoder_output_length
90
-
91
- pre_rvq_adapter_output, pre_rvq_adapter_output_length = self.pre_rvq_adapter(concated_semantic_acoustic_channel, concated_semantic_acoustic_channel_length) # (B, D, T), 50hz
92
-
93
- downsample_output, downsample_output_length = self.downsample(pre_rvq_adapter_output, pre_rvq_adapter_output_length) # (B, D, T), 50hz -> 12.5hz
94
 
95
- zq, codes, vq_loss, _, quantizer_output_length = self.quantizer(downsample_output, downsample_output_length) # (B, D, T), (nq, B, T), (nq,), (nq, B, D, T), (B,)
 
 
96
 
97
  return {
98
- "zq": zq, # (B, D, T)
99
- "codes": codes, # (nq, B, T)
100
- "codes_lengths": quantizer_output_length # (B,)
101
  }
102
-
103
- @torch.inference_mode()
104
  def inference_detokenize(self, codes, codes_lengths):
105
  """
106
- Input:
107
- codes: Quantization codes # (nq, B, T)
108
- codes_lengths: Quantization code lengths for each sample # (B,)
109
- Output:
110
- dict: Contains the following key-value pairs
111
- "y": Synthesized audio waveform # (B, 1, T)
112
- "output_length": Output lengths # (B,)
113
  """
114
- zq = self.quantizer.decode_codes(codes) # (B, D, T)
115
-
116
- post_rvq_adapter_output, post_rvq_adapter_output_length = self.post_rvq_adapter(zq, codes_lengths) # (B, D, T), 12.5hz
117
-
118
- # Acoustic channel
119
- upsample_output, upsample_output_length = self.upsample(post_rvq_adapter_output, post_rvq_adapter_output_length) # (B, D, T), 12.5hz -> 50hz
120
 
121
- acoustic_decoder_output, acoustic_decoder_output_length = self.acoustic_decoder(upsample_output, upsample_output_length) # (B, D, T), 50hz -> 100hz
 
 
122
 
123
- y, vocos_output_length = self.enhanced_vocos(acoustic_decoder_output, acoustic_decoder_output_length) # (B, 1, T), 100hz -> 16khz
124
-
125
  return {
126
- "y": y, # (B, 1, T)
127
- "output_length": vocos_output_length, # (B,)
128
  }
129
-
130
  @torch.inference_mode()
131
- def encode(self, wav_list, overlap_seconds=10, device=torch.device("cuda")):
132
  """
133
- Input:
134
- wav_list: List of audio waveforms, each with potentially different length, may exceed 30 seconds # B * (T,)
135
- overlap_seconds: Overlap in seconds, process 30 seconds at a time, keeping (30 - overlap_seconds) seconds of valid output
136
- Output:
137
- dict: Contains the following key-value pairs
138
- "codes_list": List of quantization codes # B * (nq, T)
139
  """
 
140
  duration_seconds = 30 - overlap_seconds
141
- chunk_size = int(30 * self.input_sample_rate) # Maximum samples per chunk
142
- duration_size = int(duration_seconds * self.input_sample_rate) # Valid output samples per chunk
143
- code_duration_length = duration_size // self.encoder_downsample_rate # Valid code length per chunk
 
 
 
 
144
 
145
  # Get maximum waveform length
146
  max_length = max(len(wav) for wav in wav_list)
@@ -148,8 +200,8 @@ class XY_Tokenizer(nn.Module):
148
  wav_tensor = torch.zeros(batch_size, 1, max_length, device=device)
149
  input_lengths = torch.zeros(batch_size, dtype=torch.long, device=device)
150
  for i, wav in enumerate(wav_list):
151
- wav_tensor[i, 0, :len(wav)] = wav
152
- input_lengths[i] = len(wav) # (B,)
153
 
154
  # Calculate number of chunks needed
155
  max_chunks = (max_length + duration_size - 1) // duration_size
@@ -159,121 +211,161 @@ class XY_Tokenizer(nn.Module):
159
  for chunk_idx in range(max_chunks):
160
  start = chunk_idx * duration_size
161
  end = min(start + chunk_size, max_length)
162
- chunk = wav_tensor[:, :, start:end] # (B, 1, T')
163
- chunk_lengths = torch.clamp(input_lengths - start, 0, end - start) # (B,)
164
 
165
  # Skip empty chunks
166
  if chunk_lengths.max() == 0:
167
  continue
168
 
169
  # Encode
170
- result = self.inference_tokenize(chunk, chunk_lengths) # {"zq": (B, D, T'), "codes": (nq, B, T'), "codes_lengths": (B,)}
171
- chunk_codes = result["codes"] # (nq, B, T')
172
- chunk_code_lengths = result["codes_lengths"] # (B,)
 
 
173
 
174
  # Extract valid portion
175
- valid_code_lengths = torch.clamp(chunk_code_lengths, 0, code_duration_length) # (B,)
176
- valid_chunk_codes = torch.zeros(self.nq, batch_size, code_duration_length, device=device, dtype=chunk_codes.dtype)
 
 
177
  for b in range(batch_size):
178
  if valid_code_lengths[b] > 0:
179
- valid_chunk_codes[:, b, :valid_code_lengths[b]] = chunk_codes[:, b, :valid_code_lengths[b]] # (nq, B, valid_code_length)
 
 
180
 
181
- codes_list.append(valid_chunk_codes) # (nq, B, valid_code_length)
182
 
183
  # Concatenate all chunks
184
  if codes_list:
185
- codes_tensor = torch.cat(codes_list, dim=-1) # (nq, B, T_total)
186
- codes_list = [codes_tensor[:, i, :input_lengths[i] // self.encoder_downsample_rate] for i in range(batch_size)] # B * (nq, T)
 
 
 
187
  else:
188
- codes_list = [torch.zeros(self.nq, 0, device=device, dtype=torch.long) for _ in range(batch_size)] # B * (nq, 0)
 
 
 
 
 
189
 
190
- return {
191
- "codes_list": codes_list # B * (nq, T)
192
- }
193
-
194
  @torch.inference_mode()
195
- def decode(self, codes_list, overlap_seconds=10, device=torch.device("cuda")):
196
  """
197
- Input:
198
- codes_list: List of quantization codes # B * (nq, T)
199
- overlap_seconds: Overlap in seconds, process 30 seconds at a time, keeping (30 - overlap_seconds) seconds of valid output
200
- Output:
201
- dict: Contains the following key-value pairs
202
- "syn_wav_list": List of synthesized audio waveforms # B * (T,)
203
  """
 
204
  duration_seconds = 30 - overlap_seconds
205
- chunk_code_length = int(30 * self.input_sample_rate // self.encoder_downsample_rate) # Maximum code length per chunk
206
- duration_code_length = int(duration_seconds * self.input_sample_rate // self.encoder_downsample_rate) # Valid code length per chunk
207
- duration_wav_length = duration_code_length * self.decoder_upsample_rate # Valid waveform length per chunk
 
 
 
 
 
 
208
 
209
  # Get maximum code length
210
  max_code_length = max(codes.shape[-1] for codes in codes_list)
211
  batch_size = len(codes_list)
212
- codes_tensor = torch.zeros(self.nq, batch_size, max_code_length, device=device, dtype=torch.long)
 
 
213
  code_lengths = torch.zeros(batch_size, dtype=torch.long, device=device)
214
  for i, codes in enumerate(codes_list):
215
- codes_tensor[:, i, :codes.shape[-1]] = codes.to(device)
216
- code_lengths[i] = codes.shape[-1] # (B,)
217
 
218
  # Calculate number of chunks needed
219
- max_chunks = (max_code_length + duration_code_length - 1) // duration_code_length
 
 
220
  wav_list = []
221
 
222
  # Process the entire batch in chunks
223
  for chunk_idx in range(max_chunks):
224
  start = chunk_idx * duration_code_length
225
  end = min(start + chunk_code_length, max_code_length)
226
- chunk_codes = codes_tensor[:, :, start:end] # (nq, B, T')
227
- chunk_code_lengths = torch.clamp(code_lengths - start, 0, end - start) # (B,)
 
 
228
 
229
  # Skip empty chunks
230
  if chunk_code_lengths.max() == 0:
231
  continue
232
 
233
  # Decode
234
- result = self.inference_detokenize(chunk_codes, chunk_code_lengths) # {"y": (B, 1, T'), "output_length": (B,)}
235
- chunk_wav = result["y"] # (B, 1, T')
236
- chunk_wav_lengths = result["output_length"] # (B,)
 
 
237
 
238
  # Extract valid portion
239
- valid_wav_lengths = torch.clamp(chunk_wav_lengths, 0, duration_wav_length) # (B,)
240
- valid_chunk_wav = torch.zeros(batch_size, 1, duration_wav_length, device=device)
 
 
 
 
241
  for b in range(batch_size):
242
  if valid_wav_lengths[b] > 0:
243
- valid_chunk_wav[b, :, :valid_wav_lengths[b]] = chunk_wav[b, :, :valid_wav_lengths[b]] # (B, 1, valid_wav_length)
 
 
244
 
245
- wav_list.append(valid_chunk_wav) # (B, 1, valid_wav_length)
246
 
247
  # Concatenate all chunks
248
  if wav_list:
249
- wav_tensor = torch.cat(wav_list, dim=-1) # (B, 1, T_total)
250
- syn_wav_list = [wav_tensor[i, 0, :code_lengths[i] * self.decoder_upsample_rate] for i in range(batch_size)] # B * (T,)
 
 
 
251
  else:
252
- syn_wav_list = [torch.zeros(0, device=device) for _ in range(batch_size)] # B * (0,)
253
-
254
- return {
255
- "syn_wav_list": syn_wav_list # B * (T,)
256
- }
257
-
258
  @classmethod
259
  def load_from_checkpoint(cls, config_path: str, ckpt_path: str):
260
  # Load model from configuration file and checkpoint
261
  logging.info(f"Loading model from {config_path} and {ckpt_path}")
262
-
263
  # Load configuration
264
- with open(config_path, 'r') as f:
265
  config = yaml.safe_load(f)
266
-
267
  # Create model instance
268
- model = cls(config['generator_params'])
269
-
270
  # Load checkpoint
271
- checkpoint = torch.load(ckpt_path, map_location='cpu')
272
-
273
  # Check if checkpoint contains 'generator' key
274
- if 'generator' in checkpoint:
275
- model.load_state_dict(checkpoint['generator'])
276
  else:
277
  model.load_state_dict(checkpoint)
278
-
279
- return model
 
1
  # -*- coding: utf-8 -*-
 
2
  import logging
3
+
4
  import torch
5
  import torch.nn as nn
6
  import torch.nn.functional as F
7
+ import yaml
8
 
9
  from .nn.feature_extractor import MelFeatureExtractor
10
+ from .nn.modules import (
11
+ OmniAudioDecoder,
12
+ OmniAudioEncoder,
13
+ ResidualDownConv,
14
+ Transformer,
15
+ UpConv,
16
+ Vocos,
17
+ )
18
  from .nn.quantizer import ResidualVQ
19
 
20
+
21
  class XY_Tokenizer(nn.Module):
22
  def __init__(self, generator_params):
23
  super().__init__()
24
  # Basic parameters
25
+ self.input_sample_rate = generator_params["input_sample_rate"]
26
+ self.output_sample_rate = generator_params["output_sample_rate"]
27
+
28
+ self.encoder_downsample_rate = generator_params["encoder_downsample_rate"]
29
+ self.decoder_upsample_rate = generator_params["decoder_upsample_rate"]
30
+ self.code_dim = generator_params["quantizer_kwargs"]["input_dim"]
31
+
32
  ## Codec part
33
 
34
  ## Semantic channel
35
+ self.semantic_encoder = OmniAudioEncoder(
36
+ **generator_params["semantic_encoder_kwargs"]
37
+ )
38
+
39
+ self.semantic_encoder_adapter = Transformer(
40
+ **generator_params["semantic_encoder_adapter_kwargs"]
41
+ )
42
+
43
  ## Acoustic channel
44
+ self.acoustic_encoder = OmniAudioEncoder(
45
+ **generator_params["acoustic_encoder_kwargs"]
46
+ )
47
+
48
  ## Semantic & acoustic shared parameters
49
+ self.pre_rvq_adapter = Transformer(**generator_params["pre_rvq_adapter_kwargs"])
50
+
51
+ self.downsample = ResidualDownConv(**generator_params["downsample_kwargs"])
52
+
53
+ self.quantizer = ResidualVQ(**generator_params["quantizer_kwargs"])
54
+ self.nq = generator_params["quantizer_kwargs"]["num_quantizers"]
55
+
56
+ self.post_rvq_adapter = Transformer(
57
+ **generator_params["post_rvq_adapter_kwargs"]
58
+ )
59
+
60
  ## Acoustic channel
61
+ self.upsample = UpConv(**generator_params["upsample_kwargs"])
62
 
63
+ self.acoustic_decoder = OmniAudioDecoder(
64
+ **generator_params["acoustic_decoder_kwargs"]
65
+ )
66
 
67
+ self.enhanced_vocos = Vocos(**generator_params["vocos_kwargs"])
68
 
69
  ## Feature extractor
70
+ fe_kwargs = generator_params.get("feature_extractor_kwargs", {})
71
+ self.feature_extractor = MelFeatureExtractor(**fe_kwargs)
72
 
73
  @torch.inference_mode()
74
  def inference_tokenize(self, x, input_lengths):
75
  """
76
+ Input:
77
+ x: Waveform tensor # (B, 1, T), T <= 30s * sample_rate
78
+ input_lengths: Valid length for each sample # (B,)
79
+ Output:
80
+ dict: Contains the following key-value pairs
81
+ "zq": Quantized embeddings # (B, D, T)
82
+ "codes": Quantization codes # (nq, B, T)
83
+ "codes_lengths": Quantization code lengths # (B,)
84
  """
85
+ list_x = [
86
+ xi[:, :x_len].reshape(-1).cpu().numpy()
87
+ for xi, x_len in zip(x, input_lengths)
88
+ ]
89
  features = self.feature_extractor(
90
  list_x,
91
  sampling_rate=self.input_sample_rate,
92
  return_tensors="pt",
93
+ return_attention_mask=True,
94
  )
95
+ input_mel = features["input_features"].to(x.device).to(x.dtype) # (B, D, 3000)
96
+ audio_attention_mask = features["attention_mask"].to(x.device) # (B, 3000)
97
+
98
  # Get batch size and sequence length of the input
99
+ mel_output_length = torch.sum(audio_attention_mask, dim=-1).long() # (B,)
100
+
101
  # Semantic channel
102
+ semantic_encoder_output, semantic_encoder_output_length = self.semantic_encoder(
103
+ input_mel, mel_output_length
104
+ ) # (B, D, T), 100hz -> 50hz
105
+
106
+ semantic_encoder_adapter_output, semantic_encoder_adapter_output_length = (
107
+ self.semantic_encoder_adapter(
108
+ semantic_encoder_output, semantic_encoder_output_length
109
+ )
110
+ ) # (B, D, T), 50hz
111
+
112
  # Acoustic channel
113
+ acoustic_encoder_output, acoustic_encoder_output_length = self.acoustic_encoder(
114
+ input_mel, mel_output_length
115
+ ) # (B, D, T), 100hz -> 50hz
116
+
117
  # Semantic & acoustic mixing
118
+ concated_semantic_acoustic_channel = torch.concat(
119
+ [semantic_encoder_adapter_output, acoustic_encoder_output], dim=1
120
+ ) # (B, D, T)
121
  concated_semantic_acoustic_channel_length = acoustic_encoder_output_length
 
 
 
 
122
 
123
+ pre_rvq_adapter_output, pre_rvq_adapter_output_length = self.pre_rvq_adapter(
124
+ concated_semantic_acoustic_channel,
125
+ concated_semantic_acoustic_channel_length,
126
+ ) # (B, D, T), 50hz
127
+
128
+ downsample_output, downsample_output_length = self.downsample(
129
+ pre_rvq_adapter_output, pre_rvq_adapter_output_length
130
+ ) # (B, D, T), 50hz -> 12.5hz
131
+
132
+ zq, codes, vq_loss, _, quantizer_output_length = self.quantizer(
133
+ downsample_output, downsample_output_length
134
+ ) # (B, D, T), (nq, B, T), (nq,), (nq, B, D, T), (B,)
135
 
136
  return {
137
+ "zq": zq, # (B, D, T)
138
+ "codes": codes, # (nq, B, T)
139
+ "codes_lengths": quantizer_output_length, # (B,)
140
  }
141
+
142
+ @torch.inference_mode()
143
  def inference_detokenize(self, codes, codes_lengths):
144
  """
145
+ Input:
146
+ codes: Quantization codes # (nq, B, T)
147
+ codes_lengths: Quantization code lengths for each sample # (B,)
148
+ Output:
149
+ dict: Contains the following key-value pairs
150
+ "y": Synthesized audio waveform # (B, 1, T)
151
+ "output_length": Output lengths # (B,)
152
  """
153
+ zq = self.quantizer.decode_codes(codes) # (B, D, T)
154
+
155
+ post_rvq_adapter_output, post_rvq_adapter_output_length = self.post_rvq_adapter(
156
+ zq, codes_lengths
157
+ ) # (B, D, T), 12.5hz
 
158
 
159
+ # Acoustic channel
160
+ upsample_output, upsample_output_length = self.upsample(
161
+ post_rvq_adapter_output, post_rvq_adapter_output_length
162
+ ) # (B, D, T), 12.5hz -> 50hz
163
+
164
+ acoustic_decoder_output, acoustic_decoder_output_length = self.acoustic_decoder(
165
+ upsample_output, upsample_output_length
166
+ ) # (B, D, T), 50hz -> 100hz
167
+
168
+ y, vocos_output_length = self.enhanced_vocos(
169
+ acoustic_decoder_output, acoustic_decoder_output_length
170
+ ) # (B, 1, T), 100hz -> 16khz
171
 
 
 
172
  return {
173
+ "y": y, # (B, 1, T)
174
+ "output_length": vocos_output_length, # (B,)
175
  }
176
+
177
  @torch.inference_mode()
178
+ def encode(self, wav_list, overlap_seconds=10):
179
  """
180
+ Input:
181
+ wav_list: List of audio waveforms, each with potentially different length, may exceed 30 seconds # B * (T,)
182
+ overlap_seconds: Overlap in seconds, process 30 seconds at a time, keeping (30 - overlap_seconds) seconds of valid output
183
+ Output:
184
+ dict: Contains the following key-value pairs
185
+ "codes_list": List of quantization codes # B * (nq, T)
186
  """
187
+ device = wav_list[0].device
188
  duration_seconds = 30 - overlap_seconds
189
+ chunk_size = int(30 * self.input_sample_rate) # Maximum samples per chunk
190
+ duration_size = int(
191
+ duration_seconds * self.input_sample_rate
192
+ ) # Valid output samples per chunk
193
+ code_duration_length = (
194
+ duration_size // self.encoder_downsample_rate
195
+ ) # Valid code length per chunk
196
 
197
  # Get maximum waveform length
198
  max_length = max(len(wav) for wav in wav_list)
 
200
  wav_tensor = torch.zeros(batch_size, 1, max_length, device=device)
201
  input_lengths = torch.zeros(batch_size, dtype=torch.long, device=device)
202
  for i, wav in enumerate(wav_list):
203
+ wav_tensor[i, 0, : len(wav)] = wav
204
+ input_lengths[i] = len(wav) # (B,)
205
 
206
  # Calculate number of chunks needed
207
  max_chunks = (max_length + duration_size - 1) // duration_size
 
211
  for chunk_idx in range(max_chunks):
212
  start = chunk_idx * duration_size
213
  end = min(start + chunk_size, max_length)
214
+ chunk = wav_tensor[:, :, start:end] # (B, 1, T')
215
+ chunk_lengths = torch.clamp(input_lengths - start, 0, end - start) # (B,)
216
 
217
  # Skip empty chunks
218
  if chunk_lengths.max() == 0:
219
  continue
220
 
221
  # Encode
222
+ result = self.inference_tokenize(
223
+ chunk, chunk_lengths
224
+ ) # {"zq": (B, D, T'), "codes": (nq, B, T'), "codes_lengths": (B,)}
225
+ chunk_codes = result["codes"] # (nq, B, T')
226
+ chunk_code_lengths = result["codes_lengths"] # (B,)
227
 
228
  # Extract valid portion
229
+ valid_code_lengths = torch.clamp(
230
+ chunk_code_lengths, 0, code_duration_length
231
+ ) # (B,)
232
+ valid_chunk_codes = torch.zeros(
233
+ self.nq,
234
+ batch_size,
235
+ code_duration_length,
236
+ device=device,
237
+ dtype=chunk_codes.dtype,
238
+ )
239
  for b in range(batch_size):
240
  if valid_code_lengths[b] > 0:
241
+ valid_chunk_codes[:, b, : valid_code_lengths[b]] = chunk_codes[
242
+ :, b, : valid_code_lengths[b]
243
+ ] # (nq, B, valid_code_length)
244
 
245
+ codes_list.append(valid_chunk_codes) # (nq, B, valid_code_length)
246
 
247
  # Concatenate all chunks
248
  if codes_list:
249
+ codes_tensor = torch.cat(codes_list, dim=-1) # (nq, B, T_total)
250
+ codes_list = [
251
+ codes_tensor[:, i, : input_lengths[i] // self.encoder_downsample_rate]
252
+ for i in range(batch_size)
253
+ ] # B * (nq, T)
254
  else:
255
+ codes_list = [
256
+ torch.zeros(self.nq, 0, device=device, dtype=torch.long)
257
+ for _ in range(batch_size)
258
+ ] # B * (nq, 0)
259
+
260
+ return {"codes_list": codes_list} # B * (nq, T)
261
 
 
 
 
 
262
  @torch.inference_mode()
263
+ def decode(self, codes_list, overlap_seconds=10):
264
  """
265
+ Input:
266
+ codes_list: List of quantization codes # B * (nq, T)
267
+ overlap_seconds: Overlap in seconds, process 30 seconds at a time, keeping (30 - overlap_seconds) seconds of valid output
268
+ Output:
269
+ dict: Contains the following key-value pairs
270
+ "syn_wav_list": List of synthesized audio waveforms # B * (T,)
271
  """
272
+ device = codes_list[0].device
273
  duration_seconds = 30 - overlap_seconds
274
+ chunk_code_length = int(
275
+ 30 * self.input_sample_rate // self.encoder_downsample_rate
276
+ ) # Maximum code length per chunk
277
+ duration_code_length = int(
278
+ duration_seconds * self.input_sample_rate // self.encoder_downsample_rate
279
+ ) # Valid code length per chunk
280
+ duration_wav_length = (
281
+ duration_code_length * self.decoder_upsample_rate
282
+ ) # Valid waveform length per chunk
283
 
284
  # Get maximum code length
285
  max_code_length = max(codes.shape[-1] for codes in codes_list)
286
  batch_size = len(codes_list)
287
+ codes_tensor = torch.zeros(
288
+ self.nq, batch_size, max_code_length, device=device, dtype=torch.long
289
+ )
290
  code_lengths = torch.zeros(batch_size, dtype=torch.long, device=device)
291
  for i, codes in enumerate(codes_list):
292
+ codes_tensor[:, i, : codes.shape[-1]] = codes.to(device)
293
+ code_lengths[i] = codes.shape[-1] # (B,)
294
 
295
  # Calculate number of chunks needed
296
+ max_chunks = (
297
+ max_code_length + duration_code_length - 1
298
+ ) // duration_code_length
299
  wav_list = []
300
 
301
  # Process the entire batch in chunks
302
  for chunk_idx in range(max_chunks):
303
  start = chunk_idx * duration_code_length
304
  end = min(start + chunk_code_length, max_code_length)
305
+ chunk_codes = codes_tensor[:, :, start:end] # (nq, B, T')
306
+ chunk_code_lengths = torch.clamp(
307
+ code_lengths - start, 0, end - start
308
+ ) # (B,)
309
 
310
  # Skip empty chunks
311
  if chunk_code_lengths.max() == 0:
312
  continue
313
 
314
  # Decode
315
+ result = self.inference_detokenize(
316
+ chunk_codes, chunk_code_lengths
317
+ ) # {"y": (B, 1, T'), "output_length": (B,)}
318
+ chunk_wav = result["y"] # (B, 1, T')
319
+ chunk_wav_lengths = result["output_length"] # (B,)
320
 
321
  # Extract valid portion
322
+ valid_wav_lengths = torch.clamp(
323
+ chunk_wav_lengths, 0, duration_wav_length
324
+ ) # (B,)
325
+ valid_chunk_wav = torch.zeros(
326
+ batch_size, 1, duration_wav_length, device=device
327
+ )
328
  for b in range(batch_size):
329
  if valid_wav_lengths[b] > 0:
330
+ valid_chunk_wav[b, :, : valid_wav_lengths[b]] = chunk_wav[
331
+ b, :, : valid_wav_lengths[b]
332
+ ] # (B, 1, valid_wav_length)
333
 
334
+ wav_list.append(valid_chunk_wav) # (B, 1, valid_wav_length)
335
 
336
  # Concatenate all chunks
337
  if wav_list:
338
+ wav_tensor = torch.cat(wav_list, dim=-1) # (B, 1, T_total)
339
+ syn_wav_list = [
340
+ wav_tensor[i, 0, : code_lengths[i] * self.decoder_upsample_rate]
341
+ for i in range(batch_size)
342
+ ] # B * (T,)
343
  else:
344
+ syn_wav_list = [
345
+ torch.zeros(0, device=device) for _ in range(batch_size)
346
+ ] # B * (0,)
347
+
348
+ return {"syn_wav_list": syn_wav_list} # B * (T,)
349
+
350
  @classmethod
351
  def load_from_checkpoint(cls, config_path: str, ckpt_path: str):
352
  # Load model from configuration file and checkpoint
353
  logging.info(f"Loading model from {config_path} and {ckpt_path}")
354
+
355
  # Load configuration
356
+ with open(config_path, "r") as f:
357
  config = yaml.safe_load(f)
358
+
359
  # Create model instance
360
+ model = cls(config["generator_params"])
361
+
362
  # Load checkpoint
363
+ checkpoint = torch.load(ckpt_path, map_location="cpu")
364
+
365
  # Check if checkpoint contains 'generator' key
366
+ if "generator" in checkpoint:
367
+ model.load_state_dict(checkpoint["generator"])
368
  else:
369
  model.load_state_dict(checkpoint)
370
+
371
+ return model
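
A usage sketch (an assumption, not part of the diff) of the updated encode/decode API: the device argument was dropped, so the input tensors' device is used, and the sample rates now come from the config (16 kHz in, 32 kHz out). The checkpoint path and "example.wav" are placeholders.

import torch
import torchaudio

from XY_Tokenizer.xy_tokenizer.model import XY_Tokenizer

spt = XY_Tokenizer.load_from_checkpoint(
    config_path="XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml",
    ckpt_path="XY_Tokenizer/weights/MOSS_TTSD_tokenizer",  # placeholder path
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
spt = spt.to(device).eval()

# "example.wav" is a placeholder; encode() expects 16 kHz mono waveforms
wav, sr = torchaudio.load("example.wav")
wav = torchaudio.functional.resample(wav, sr, spt.input_sample_rate).mean(dim=0)

enc = spt.encode([wav.to(device)], overlap_seconds=10)   # {"codes_list": B * (nq, T)}
dec = spt.decode(enc["codes_list"], overlap_seconds=10)  # {"syn_wav_list": B * (T,)}
torchaudio.save(
    "reconstructed.wav",
    dec["syn_wav_list"][0].unsqueeze(0).cpu(),
    spt.output_sample_rate,  # 32000 per the new config
)

With encoder_downsample_rate 1280 at 16 kHz input, the codes run at 12.5 Hz; decode upsamples each code frame by decoder_upsample_rate 2560 to produce 32 kHz audio.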
app.py CHANGED
@@ -131,15 +131,15 @@ LANGUAGES = {
  # Model configuration
  SYSTEM_PROMPT = "You are a speech synthesizer that generates natural, realistic, and human-like conversational audio from dialogue text."
  MODEL_PATH = os.environ["MODEL_REPO_ID"]
- SPT_CONFIG_PATH = "XY_Tokenizer/config/xy_tokenizer_config.yaml"
+ SPT_CONFIG_PATH = "XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml"
  # SPT_CHECKPOINT_PATH = "XY_Tokenizer/weights/xy_tokenizer.ckpt"
  MAX_CHANNELS = 8

  from huggingface_hub import hf_hub_download

  SPT_CHECKPOINT_PATH = hf_hub_download(
-     repo_id="fnlp/XY_Tokenizer_TTSD_V0",
-     filename="xy_tokenizer.ckpt",
+     repo_id="OpenMOSS-Team/MOSS_TTSD_tokenizer",
+     filename="MOSS_TTSD_tokenizer",
      cache_dir="XY_Tokenizer/weights"
  )

@@ -245,7 +245,8 @@ def process_single_audio_generation(
          device=device,
          system_prompt=SYSTEM_PROMPT,
          start_idx=0,
-         use_normalize=use_normalize
+         use_normalize=use_normalize,
+         silence_duration=0.1,
      )

      # Check results
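
A hedged sketch of how the updated app.py constants feed the new load_model() signature from generation_utils.py. It assumes the MODEL_REPO_ID environment variable is set and the script runs from the Space root so that generation_utils is importable.

import os

import torch
from huggingface_hub import hf_hub_download

from generation_utils import load_model

SPT_CONFIG_PATH = "XY_Tokenizer/config/MOSS_TTSD_tokenizer.yaml"
SPT_CHECKPOINT_PATH = hf_hub_download(
    repo_id="OpenMOSS-Team/MOSS_TTSD_tokenizer",
    filename="MOSS_TTSD_tokenizer",
    cache_dir="XY_Tokenizer/weights",
)

tokenizer, model, spt = load_model(
    os.environ["MODEL_REPO_ID"],    # same source as MODEL_PATH in app.py
    SPT_CONFIG_PATH,
    SPT_CHECKPOINT_PATH,
    torch_dtype=torch.bfloat16,     # defaults of the new load_model signature
    attn_implementation="sdpa",
)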
generation_utils.py CHANGED
@@ -1,86 +1,181 @@
1
  import os
2
  import re
3
 
 
4
  import torch
5
  import torchaudio
6
- import numpy as np
7
-
8
- from transformers import AutoTokenizer
9
- from modeling_asteroid import AsteroidTTSInstruct
10
- from XY_Tokenizer.xy_tokenizer.model import XY_Tokenizer
11
 
12
  MAX_CHANNELS = 8
13
- SILENCE_DURATION = 0.0 # Fixed silence duration: 0 seconds
14
 
15
- def load_model(model_path, spt_config_path, spt_checkpoint_path):
 
 
16
  tokenizer = AutoTokenizer.from_pretrained(model_path)
17
-
18
- model = AsteroidTTSInstruct.from_pretrained(model_path, torch_dtype=torch.bfloat16, attn_implementation="sdpa")
 
 
 
 
19
 
20
- spt = XY_Tokenizer.load_from_checkpoint(config_path=spt_config_path, ckpt_path=spt_checkpoint_path)
21
-
22
  model.eval()
23
  spt.eval()
24
  return tokenizer, model, spt
25
 
26
 
27
  def process_jsonl_item(item):
28
- """Process JSONL data items and extract audio and text information according to the new format"""
29
- base_path = item.get("base_path", "")
 
 
 
 
 
30
  text = item.get("text", "")
31
-
32
- # Process prompt audio and text
33
- if "prompt_audio" in item and "prompt_text" in item:
34
- print("Using prompt_audio and prompt_text directly from item.")
35
- # If prompt_audio and prompt_text exist, use them directly
36
- prompt_audio = item["prompt_audio"]
37
- prompt_text = item["prompt_text"]
38
-
39
- # Only perform path joining when prompt_audio is a string path
40
- if isinstance(prompt_audio, str) and base_path and prompt_audio:
41
- prompt_audio = os.path.join(base_path, prompt_audio)
42
- else:
43
- print("Using speaker1 and speaker2 information for prompt audio and text.")
44
- # Otherwise, merge speaker1 and speaker2 information
45
- prompt_audio_speaker1 = item.get("prompt_audio_speaker1", "")
46
- prompt_text_speaker1 = item.get("prompt_text_speaker1", "")
47
- prompt_audio_speaker2 = item.get("prompt_audio_speaker2", "")
48
- prompt_text_speaker2 = item.get("prompt_text_speaker2", "")
49
-
50
- # Process audio: if it's a string path, perform path joining; if it's a tuple, use directly
51
- if isinstance(prompt_audio_speaker1, str):
52
- speaker1_audio = os.path.join(base_path, prompt_audio_speaker1) if base_path and prompt_audio_speaker1 else prompt_audio_speaker1
 
 
53
  else:
54
- speaker1_audio = prompt_audio_speaker1 # Use tuple directly
55
-
56
- if isinstance(prompt_audio_speaker2, str):
57
- speaker2_audio = os.path.join(base_path, prompt_audio_speaker2) if base_path and prompt_audio_speaker2 else prompt_audio_speaker2
58
  else:
59
- speaker2_audio = prompt_audio_speaker2 # Use tuple directly
60
-
61
- prompt_audio = {
62
- "speaker1": speaker1_audio,
63
- "speaker2": speaker2_audio
64
- }
65
-
66
- # Merge text
67
- prompt_text = ""
68
- if prompt_text_speaker1:
69
- prompt_text += f"[S1]{prompt_text_speaker1}"
70
- if prompt_text_speaker2:
71
- prompt_text += f"[S2]{prompt_text_speaker2}"
72
- prompt_text = prompt_text.strip()
73
-
74
- return {
75
- "text": text,
76
- "prompt_text": prompt_text,
77
- "prompt_audio": prompt_audio
78
- }
79
 
80
 
81
  def load_audio_data(prompt_audio, target_sample_rate=16000):
82
  """Load audio data and return processed audio tensor
83
-
84
  Args:
85
  prompt_audio: Can be in the following formats:
86
  - String: audio file path
@@ -89,10 +184,14 @@ def load_audio_data(prompt_audio, target_sample_rate=16000):
89
  """
90
  if prompt_audio is None:
91
  return None
92
-
93
  try:
94
  # Check if prompt_audio is a dictionary (containing speaker1 and speaker2)
95
- if isinstance(prompt_audio, dict) and "speaker1" in prompt_audio and "speaker2" in prompt_audio:
 
 
 
 
96
  # Process audio from both speakers separately
97
  wav1, sr1 = _load_single_audio(prompt_audio["speaker1"])
98
  wav2, sr2 = _load_single_audio(prompt_audio["speaker2"])
@@ -104,14 +203,14 @@ def load_audio_data(prompt_audio, target_sample_rate=16000):
104
  # Single audio
105
  wav, sr = _load_single_audio(prompt_audio)
106
  # Resample to 16k
107
- if sr != target_sample_rate:
108
  wav = torchaudio.functional.resample(wav, sr, target_sample_rate)
109
  # Ensure mono channel
110
  if wav.shape[0] > 1:
111
  wav = wav.mean(dim=0, keepdim=True) # Convert multi-channel to mono
112
- if len(wav.shape) == 1:
113
  wav = wav.unsqueeze(0)
114
-
115
  return wav
116
  except Exception as e:
117
  print(f"Error loading audio data: {e}")
@@ -120,10 +219,10 @@ def load_audio_data(prompt_audio, target_sample_rate=16000):
120
 
121
  def _load_single_audio(audio_input):
122
  """Load single audio, supports file path or (wav, sr) tuple
123
-
124
  Args:
125
  audio_input: String (file path) or tuple (wav, sr)
126
-
127
  Returns:
128
  tuple: (wav, sr)
129
  """
@@ -150,8 +249,8 @@ def merge_speaker_audios(wav1, sr1, wav2, sr2, target_sample_rate=16000):
150
  wav1 = wav1.mean(dim=0, keepdim=True) # Convert multi-channel to mono
151
  if len(wav1.shape) == 1:
152
  wav1 = wav1.unsqueeze(0)
153
-
154
- # Process second audio
155
  if sr2 != target_sample_rate:
156
  wav2 = torchaudio.functional.resample(wav2, sr2, target_sample_rate)
157
  # Ensure mono channel
@@ -159,7 +258,7 @@ def merge_speaker_audios(wav1, sr1, wav2, sr2, target_sample_rate=16000):
159
  wav2 = wav2.mean(dim=0, keepdim=True) # Convert multi-channel to mono
160
  if len(wav2.shape) == 1:
161
  wav2 = wav2.unsqueeze(0)
162
-
163
  # Concatenate audio
164
  merged_wav = torch.cat([wav1, wav2], dim=1)
165
  return merged_wav
@@ -168,34 +267,48 @@ def merge_speaker_audios(wav1, sr1, wav2, sr2, target_sample_rate=16000):
168
  raise
169
 
170
 
171
- def process_inputs(tokenizer, spt, prompt, text, device, audio_data=None, max_channels=8, pad_token=1024):
 
 
172
  seq = f"<|begin_of_style|>{prompt}<|end_of_style|>\n<|begin_of_text|>{text}<|end_of_text|>\n<|begin_of_speech|>"
173
  inputs1 = np.array(tokenizer.encode(seq))
174
  input_ids = np.full((inputs1.shape[0], max_channels), pad_token)
175
  input_ids[:, 0] = inputs1
176
-
177
  if audio_data is not None:
178
  try:
179
  # audio_data should now be a processed audio tensor
180
  wav = audio_data
181
-
182
  # Add fixed 5-second silence at the end of audio (using 16k sample rate)
183
- silence_samples = int(SILENCE_DURATION * 16000)
184
  silence = torch.zeros(wav.shape[0], silence_samples)
185
  wav = torch.cat([wav, silence], dim=1)
186
-
187
  with torch.no_grad():
188
  # Use SPT encoding
189
  encode_result = spt.encode([wav.squeeze().to(device)])
190
- audio_token = encode_result["codes_list"][0].permute(1, 0).cpu().numpy() # Adjust dimension order
191
-
 
 
192
  # similar to DAC encoding adjustment
193
- audio_token[:, 0] = audio_token[:, 0] + 151665 # Keep this line if offset is needed, otherwise delete
 
 
194
  input_ids = np.concatenate([input_ids, audio_token])
195
  except Exception as e:
196
  print(f"Error processing audio data: {e}")
197
  raise
198
-
199
  return input_ids
200
 
201
 
@@ -203,7 +316,9 @@ def shifting_inputs(input_ids, tokenizer, pad_token=1024, max_channels=8):
203
  seq_len = input_ids.shape[0]
204
  new_seq_len = seq_len + max_channels - 1
205
  shifted_input_ids = np.full((new_seq_len, max_channels), pad_token, dtype=np.int64)
206
- shifted_input_ids[:, 0] = np.full(new_seq_len, tokenizer.pad_token_id, dtype=np.int64)
 
 
207
  for i in range(max_channels):
208
  shifted_input_ids[i : (seq_len + i), i] = input_ids[:, i]
209
  return shifted_input_ids
@@ -213,7 +328,7 @@ def rpadding(input_ids, channels, tokenizer):
213
  attention_masks = [np.ones(inputs.shape[0]) for inputs in input_ids]
214
  max_length = max(ids.shape[0] for ids in input_ids)
215
  padded_input_ids, padded_attns = [], []
216
-
217
  for ids, attn in zip(input_ids, attention_masks):
218
  pad_len = max_length - ids.shape[0]
219
  input_pad = np.full((pad_len, channels), 1024)
@@ -245,26 +360,23 @@ def normalize_text(text: str) -> str:
245
  Normalize multi-speaker script.
246
 
247
  1. Don't preserve line breaks.
248
- 2. Remove brackets for non-speaker tags (if [] doesn't contain S1/S2...Sx format, remove the brackets themselves).
249
- 3. Remove decorative symbols: 【】《》()『』「」"-“” .
250
- 4. Internal punctuation !;:、 → ,;only allow ? and ,。
251
  5. Multiple 。 keep only the last one, others → ,。
252
  6. Replace consecutive "哈" (>=2) with "(笑)".
253
  7. Auto-recognize [S1] / [S2] … tags; if missing, treat as whole segment.
 
254
  """
255
  # Replace [1], [2] etc. format with [S1], [S2] etc. format
256
- text = re.sub(r'\[(\d+)\]', r'[S\1]', text)
257
 
258
  # Remove decorative characters
259
- remove_chars = "【】《》()『』「」""\"-“”"
260
-
261
-
262
- # Remove brackets for non-speaker tags (keep content, only remove brackets themselves)
263
- text = re.sub(r'\[(?!S\d+\])([^\]]*)\]', r'\1', text)
264
 
265
  # Use positive lookahead to split text by speaker tags (tags themselves are still preserved)
266
- segments = re.split(r'(?=\[S\d+\])', text.replace("\n", " "))
267
- normalized_lines = []
268
 
269
  for seg in segments:
270
  seg = seg.strip()
@@ -272,42 +384,73 @@ def normalize_text(text: str) -> str:
272
  continue
273
 
274
  # Extract tags
275
- m = re.match(r'^(\[S\d+\])\s*(.*)', seg)
276
- tag, content = m.groups() if m else ('', seg)
277
 
278
  # Remove irrelevant symbols
279
  content = re.sub(f"[{re.escape(remove_chars)}]", "", content)
280
 
281
  # Handle consecutive "哈" characters: replace 2 or more with "(笑)"
282
- content = re.sub(r'哈{2,}', '()', content)
 
 
 
283
 
284
  # First handle multi-character punctuation marks
285
- content = content.replace('——', '')
286
- content = content.replace('……', '')
287
 
288
  # Handle single-character internal punctuation marks
289
- internal_punct_map = str.maketrans({
290
- '!': '', '!': ',',
291
- ';': ',', ';': ',',
292
- ':': ',', ':': ',',
293
- '、': ',',
294
- '?': ',', '?': ','
295
- })
296
  content = content.translate(internal_punct_map)
297
  content = content.strip()
298
 
299
  # Keep only the final period
300
  if len(content) > 1:
301
- last_ch = "。" if content[-1] == "," else ("." if content[-1] == "," else content[-1])
302
- body = content[:-1].replace('', ',')
 
 
 
 
303
  content = body + last_ch
304
 
305
- normalized_lines.append(f"{tag}{content}".strip())
306
 
307
- return "".join(normalized_lines)
 
308
 
 
 
 
 
309
 
310
- def process_batch(batch_items, tokenizer, model, spt, device, system_prompt, start_idx, use_normalize=False):
 
 
311
  """Process a batch of data items and generate audio, return audio data and metadata"""
312
  try:
313
  # Prepare batch data
@@ -316,64 +459,74 @@ def process_batch(batch_items, tokenizer, model, spt, device, system_prompt, sta
316
  prompts = [system_prompt] * batch_size
317
  prompt_audios = []
318
  actual_texts_data = [] # Store actual text data used
319
-
320
  print(f"Processing {batch_size} samples starting from index {start_idx}...")
321
-
322
  # Extract text and audio from each sample
323
  for i, item in enumerate(batch_items):
324
  # Use new processing function
325
  processed_item = process_jsonl_item(item)
326
-
327
  text = processed_item["text"]
328
  prompt_text = processed_item["prompt_text"]
329
-
330
- # Merge text
331
- full_text = prompt_text + text
332
  original_full_text = full_text # Save original text
333
-
334
  # Apply text normalization based on parameter
335
  if use_normalize:
336
  full_text = normalize_text(full_text)
337
-
338
  # Replace speaker tags
339
- final_text = full_text.replace("[S1]", "<speaker1>").replace("[S2]", "<speaker2>")
 
 
340
  texts.append(final_text)
341
-
342
  # Save actual text information used
343
- actual_texts_data.append({
344
- "index": start_idx + i,
345
- "original_text": original_full_text,
346
- "normalized_text": normalize_text(original_full_text) if use_normalize else None,
347
- "final_text": final_text,
348
- "use_normalize": use_normalize
349
- })
350
-
 
 
 
 
351
  # Get reference audio
352
  prompt_audios.append(processed_item["prompt_audio"])
353
-
354
  # Process inputs
355
  input_ids_list = []
356
- for i, (text, prompt, audio_path) in enumerate(zip(texts, prompts, prompt_audios)):
 
 
357
  # Load audio data here
358
  audio_data = load_audio_data(audio_path) if audio_path else None
359
- inputs = process_inputs(tokenizer, spt, prompt, text, device, audio_data)
 
 
360
  inputs = shifting_inputs(inputs, tokenizer)
361
  input_ids_list.append(inputs)
362
-
363
  # Pad batch inputs
364
  input_ids, attention_mask = rpadding(input_ids_list, MAX_CHANNELS, tokenizer)
365
-
366
  # Batch generation
367
  print(f"Starting batch audio generation...")
368
  start = input_ids.shape[1] - MAX_CHANNELS + 1
369
-
370
  # Move inputs to GPU
371
  input_ids = input_ids.to(device)
372
  attention_mask = attention_mask.to(device)
373
-
374
  # Generate model outputs
375
  outputs = model.generate(
376
- input_ids=input_ids,
377
  attention_mask=attention_mask,
378
  )
379
  print(f"Original outputs shape: {outputs.shape}")
@@ -385,20 +538,19 @@ def process_batch(batch_items, tokenizer, model, spt, device, system_prompt, sta
385
  outputs = outputs[:, start:]
386
  seq_len = outputs.shape[1] - MAX_CHANNELS + 1
387
  speech_ids = torch.full((outputs.shape[0], seq_len, MAX_CHANNELS), 0).to(device)
388
-
389
-
390
  # Adjust output format
391
  for j in range(MAX_CHANNELS):
392
  speech_ids[..., j] = outputs[:, j : seq_len + j, j]
393
- if j == 0:
394
  speech_ids[..., j] = speech_ids[..., j] - 151665
395
-
396
  # Find valid positions for each sample
397
  li = find_max_valid_positions(speech_ids)
398
-
399
  # Store audio result data
400
  audio_results = []
401
-
402
  # Process batch sample results individually
403
  for i in range(batch_size):
404
  try:
@@ -408,39 +560,200 @@ def process_batch(batch_items, tokenizer, model, spt, device, system_prompt, sta
408
  print(f"Sample {start_idx + i} has no valid speech tokens")
409
  audio_results.append(None)
410
  continue
411
-
412
  this_speech_id = speech_ids[i, :end_idx]
413
- print(f"Speech token shape for sample {start_idx + i}: {this_speech_id.shape}")
414
-
415
- # Decode generated audio
416
- with torch.no_grad():
417
- codes_list = [this_speech_id.permute(1, 0)] # Convert to SPT expected format
418
- decode_result = spt.decode(codes_list, overlap_seconds=10)
419
- audio_result = decode_result["syn_wav_list"][0].cpu().detach()
420
-
421
- if audio_result.ndim == 1: # If 1D [samples]
422
- audio_result = audio_result.unsqueeze(0) # Convert to 2D [1, samples]
423
-
424
- # Save audio data instead of file path
425
- audio_results.append({
426
- "audio_data": audio_result,
427
- "sample_rate": spt.output_sample_rate,
428
- "index": start_idx + i
429
- })
430
- print(f"Audio generation completed: sample {start_idx + i}")
431
-
 
 
432
  except Exception as e:
433
  print(f"Error processing sample {start_idx + i}: {str(e)}, skipping...")
434
  import traceback
 
435
  traceback.print_exc()
436
  audio_results.append(None)
437
-
438
  # Clean up GPU memory
439
  torch.cuda.empty_cache()
440
-
441
  # Return text data and audio data
442
  return actual_texts_data, audio_results
443
-
444
  except Exception as e:
445
  print(f"Error during batch processing: {str(e)}")
446
- raise
 
1
  import os
2
  import re
3
 
4
+ import numpy as np
5
  import torch
6
  import torchaudio
7
 
8
  MAX_CHANNELS = 8
 
9
 
10
+ def pad_or_truncate_to_seconds(
11
+ wav: torch.Tensor, target_seconds: float, sr: int
12
+ ) -> torch.Tensor:
13
+ """Pad or truncate a mono waveform to target length in seconds.
14
+
15
+ Args:
16
+ wav: (1, T) or (T,) tensor
17
+ target_seconds: target duration in seconds
18
+ sr: sample rate
19
+ Returns:
20
+ (1, T_target) tensor
21
+ """
22
+ if wav.dim() == 2 and wav.shape[0] == 1:
23
+ wav_1d = wav.squeeze(0)
24
+ else:
25
+ wav_1d = wav.reshape(-1)
26
+ target_len = int(round(target_seconds * sr))
27
+ cur_len = wav_1d.shape[-1]
28
+ if cur_len == target_len:
29
+ out = wav_1d
30
+ elif cur_len > target_len:
31
+ out = wav_1d[:target_len]
32
+ else:
33
+ pad_len = target_len - cur_len
34
+ out = torch.cat(
35
+ [wav_1d, torch.zeros(pad_len, dtype=wav_1d.dtype, device=wav_1d.device)],
36
+ dim=-1,
37
+ )
38
+ return out.unsqueeze(0)
39
+
40
+
41
+ def crossfade_concat(
42
+ segments: list, sample_rate: int, crossfade_seconds: float = 0.1
43
+ ) -> torch.Tensor:
44
+ """Concatenate segments with linear crossfade.
45
+
46
+ Args:
47
+ segments: list of (1, T) tensors
48
+ sample_rate: sampling rate
49
+ crossfade_seconds: overlap time for crossfade
50
+ Returns:
51
+ (1, T_total) tensor
52
+ """
53
+ if len(segments) == 0:
54
+ return torch.zeros(1, 0)
55
+ if len(segments) == 1:
56
+ return segments[0]
57
+ out = segments[0]
58
+ cf_len_target = int(round(crossfade_seconds * sample_rate))
59
+ for k in range(1, len(segments)):
60
+ nxt = segments[k]
61
+ if cf_len_target <= 0:
62
+ out = torch.cat([out, nxt], dim=-1)
63
+ continue
64
+ cf_len = min(cf_len_target, out.shape[-1], nxt.shape[-1])
65
+ if cf_len <= 0:
66
+ out = torch.cat([out, nxt], dim=-1)
67
+ continue
68
+ fade_out = torch.linspace(
69
+ 1.0, 0.0, steps=cf_len, dtype=out.dtype, device=out.device
70
+ )
71
+ fade_in = torch.linspace(
72
+ 0.0, 1.0, steps=cf_len, dtype=nxt.dtype, device=nxt.device
73
+ )
74
+ overlap = out[0, -cf_len:] * fade_out + nxt[0, :cf_len] * fade_in
75
+ out = torch.cat(
76
+ [out[:, :-cf_len], overlap.unsqueeze(0), nxt[:, cf_len:]], dim=-1
77
+ )
78
+ return out
79
+
80
+ def load_model(
81
+ model_path,
82
+ spt_config_path,
83
+ spt_checkpoint_path,
84
+ torch_dtype=torch.bfloat16,
85
+ attn_implementation="sdpa",
86
+ ):
87
+ from transformers import AutoTokenizer
88
+
89
+ from modeling_asteroid import AsteroidTTSInstruct
90
+ from XY_Tokenizer.xy_tokenizer.model import XY_Tokenizer
91
+
92
  tokenizer = AutoTokenizer.from_pretrained(model_path)
93
+ model = AsteroidTTSInstruct.from_pretrained(
94
+ model_path, torch_dtype=torch_dtype, attn_implementation=attn_implementation
95
+ )
96
+ spt = XY_Tokenizer.load_from_checkpoint(
97
+ config_path=spt_config_path, ckpt_path=spt_checkpoint_path
98
+ )
99
 
 
 
100
  model.eval()
101
  spt.eval()
102
  return tokenizer, model, spt
103
 
104
 
105
  def process_jsonl_item(item):
106
+ """Parse a JSONL item enforcing prompt requirement.
107
+
108
+ Only supports Format 1 (separate speaker refs) and Format 2 (shared ref),
109
+ consistent with the updated README. If `base_path` is missing/empty, any
110
+ string paths must be absolute. Text-only input is not supported and will raise.
111
+ """
112
+ base_path = item.get("base_path", "") or ""
113
  text = item.get("text", "")
114
+
115
+ def _resolve_path(p: str) -> str:
116
+ if not isinstance(p, str) or not p:
117
+ return p
118
+ if base_path:
119
+ return os.path.join(base_path, p)
120
+ # base_path missing: require absolute path
121
+ if not os.path.isabs(p):
122
+ raise ValueError(
123
+ "When base_path is omitted, audio paths must be absolute. Got: " + p
124
+ )
125
+ return p
126
+
127
+ # Try Format 2 first: shared audio reference
128
+ prompt_audio = None
129
+ prompt_text = ""
130
+ if "prompt_audio" in item:
131
+ prompt_audio_val = item.get("prompt_audio")
132
+ if not prompt_audio_val:
133
+ raise ValueError("Format 2 requires non-empty 'prompt_audio'.")
134
+ if isinstance(prompt_audio_val, str):
135
+ prompt_audio = _resolve_path(prompt_audio_val)
136
+ else:
137
+ # allow tuple form for backward-compatibility
138
+ prompt_audio = prompt_audio_val
139
+ prompt_text = item.get("prompt_text", "")
140
+ return {"text": text, "prompt_text": prompt_text, "prompt_audio": prompt_audio}
141
+
142
+ # Try Format 1: separate speaker references
143
+ s1 = item.get("prompt_audio_speaker1", "")
144
+ s2 = item.get("prompt_audio_speaker2", "")
145
+ has_s1 = (isinstance(s1, str) and s1) or isinstance(s1, tuple)
146
+ has_s2 = (isinstance(s2, str) and s2) or isinstance(s2, tuple)
147
+
148
+ if has_s1 and has_s2:
149
+ if isinstance(s1, str) and s1:
150
+ s1_resolved = _resolve_path(s1)
151
  else:
152
+ s1_resolved = s1
153
+ if isinstance(s2, str) and s2:
154
+ s2_resolved = _resolve_path(s2)
 
155
  else:
156
+ s2_resolved = s2
157
+ # Build merged prompt audio dict
158
+ prompt_audio = {"speaker1": s1_resolved, "speaker2": s2_resolved}
159
+ # Merge texts
160
+ pt1 = item.get("prompt_text_speaker1", "")
161
+ pt2 = item.get("prompt_text_speaker2", "")
162
+ merged = ""
163
+ if pt1:
164
+ merged += f"[S1]{pt1}"
165
+ if pt2:
166
+ merged += f"[S2]{pt2}"
167
+ prompt_text = merged.strip()
168
+ return {"text": text, "prompt_text": prompt_text, "prompt_audio": prompt_audio}
169
+
170
+ # Otherwise, no supported prompt found → reject (text-only unsupported)
171
+ raise ValueError(
172
+ "Input must include prompt (Format 1 or 2). Text-only is not supported."
173
+ )
 
 
174
 
175
 
176
  def load_audio_data(prompt_audio, target_sample_rate=16000):
177
  """Load audio data and return processed audio tensor
178
+
179
  Args:
180
  prompt_audio: Can be in the following formats:
181
  - String: audio file path
 
184
  """
185
  if prompt_audio is None:
186
  return None
187
+
188
  try:
189
  # Check if prompt_audio is a dictionary (containing speaker1 and speaker2)
190
+ if (
191
+ isinstance(prompt_audio, dict)
192
+ and "speaker1" in prompt_audio
193
+ and "speaker2" in prompt_audio
194
+ ):
195
  # Process audio from both speakers separately
196
  wav1, sr1 = _load_single_audio(prompt_audio["speaker1"])
197
  wav2, sr2 = _load_single_audio(prompt_audio["speaker2"])
 
203
  # Single audio
204
  wav, sr = _load_single_audio(prompt_audio)
205
  # Resample to 16k
206
+ if sr != target_sample_rate:
207
  wav = torchaudio.functional.resample(wav, sr, target_sample_rate)
208
  # Ensure mono channel
209
  if wav.shape[0] > 1:
210
  wav = wav.mean(dim=0, keepdim=True) # Convert multi-channel to mono
211
+ if len(wav.shape) == 1:
212
  wav = wav.unsqueeze(0)
213
+
214
  return wav
215
  except Exception as e:
216
  print(f"Error loading audio data: {e}")
 
219
 
220
  def _load_single_audio(audio_input):
221
  """Load single audio, supports file path or (wav, sr) tuple
222
+
223
  Args:
224
  audio_input: String (file path) or tuple (wav, sr)
225
+
226
  Returns:
227
  tuple: (wav, sr)
228
  """
 
249
  wav1 = wav1.mean(dim=0, keepdim=True) # Convert multi-channel to mono
250
  if len(wav1.shape) == 1:
251
  wav1 = wav1.unsqueeze(0)
252
+
253
+ # Process second audio
254
  if sr2 != target_sample_rate:
255
  wav2 = torchaudio.functional.resample(wav2, sr2, target_sample_rate)
256
  # Ensure mono channel
 
258
  wav2 = wav2.mean(dim=0, keepdim=True) # Convert multi-channel to mono
259
  if len(wav2.shape) == 1:
260
  wav2 = wav2.unsqueeze(0)
261
+
262
  # Concatenate audio
263
  merged_wav = torch.cat([wav1, wav2], dim=1)
264
  return merged_wav
 
267
  raise
268
 
269
 
270
+ def process_inputs(
271
+ tokenizer,
272
+ spt,
273
+ prompt,
274
+ text,
275
+ device,
276
+ silence_duration,
277
+ audio_data=None,
278
+ max_channels=8,
279
+ pad_token=1024,
280
+ ):
281
  seq = f"<|begin_of_style|>{prompt}<|end_of_style|>\n<|begin_of_text|>{text}<|end_of_text|>\n<|begin_of_speech|>"
282
  inputs1 = np.array(tokenizer.encode(seq))
283
  input_ids = np.full((inputs1.shape[0], max_channels), pad_token)
284
  input_ids[:, 0] = inputs1
285
+
286
  if audio_data is not None:
287
  try:
288
  # audio_data should now be a processed audio tensor
289
  wav = audio_data
290
+
291
  # Add fixed 5-second silence at the end of audio (using 16k sample rate)
292
+ silence_samples = int(silence_duration * 16000)
293
  silence = torch.zeros(wav.shape[0], silence_samples)
294
  wav = torch.cat([wav, silence], dim=1)
295
+
296
  with torch.no_grad():
297
  # Use SPT encoding
298
  encode_result = spt.encode([wav.squeeze().to(device)])
299
+ audio_token = (
300
+ encode_result["codes_list"][0].permute(1, 0).cpu().numpy()
301
+ ) # Adjust dimension order
302
+
303
  # similar to DAC encoding adjustment
304
+ audio_token[:, 0] = (
305
+ audio_token[:, 0] + 151665
306
+ ) # Keep this line if offset is needed, otherwise delete
307
  input_ids = np.concatenate([input_ids, audio_token])
308
  except Exception as e:
309
  print(f"Error processing audio data: {e}")
310
  raise
311
+
312
  return input_ids
313
 
314
 
 
316
  seq_len = input_ids.shape[0]
317
  new_seq_len = seq_len + max_channels - 1
318
  shifted_input_ids = np.full((new_seq_len, max_channels), pad_token, dtype=np.int64)
319
+ shifted_input_ids[:, 0] = np.full(
320
+ new_seq_len, tokenizer.pad_token_id, dtype=np.int64
321
+ )
322
  for i in range(max_channels):
323
  shifted_input_ids[i : (seq_len + i), i] = input_ids[:, i]
324
  return shifted_input_ids
 
328
  attention_masks = [np.ones(inputs.shape[0]) for inputs in input_ids]
329
  max_length = max(ids.shape[0] for ids in input_ids)
330
  padded_input_ids, padded_attns = [], []
331
+
332
  for ids, attn in zip(input_ids, attention_masks):
333
  pad_len = max_length - ids.shape[0]
334
  input_pad = np.full((pad_len, channels), 1024)
 
360
  Normalize multi-speaker script.
361
 
362
  1. Don't preserve line breaks.
363
+ 2. Preserve bracketed segments like [] () <> even when they are not speaker tags.
364
+ 3. Remove decorative symbols: 【】《》()『』「」~~-_.
365
+ 4. Internal punctuation ;:、 → ,;keep ?!?.
366
  5. Multiple 。 keep only the last one, others → ,。
367
  6. Replace consecutive "哈" (>=2) with "(笑)".
368
  7. Auto-recognize [S1] / [S2] … tags; if missing, treat as whole segment.
369
+ 8. Merge adjacent identical speaker tags.
370
  """
371
  # Replace [1], [2] etc. format with [S1], [S2] etc. format
372
+ text = re.sub(r"\[(\d+)\]", r"[S\1]", text)
373
 
374
  # Remove decorative characters
375
+ remove_chars = "【】《》()『』「」" '"-_“”~~‘’'
 
 
 
 
376
 
377
  # Use positive lookahead to split text by speaker tags (tags themselves are still preserved)
378
+ segments = re.split(r"(?=\[S\d+\])", text.replace("\n", " "))
379
+ processed_parts = []
380
 
381
  for seg in segments:
382
  seg = seg.strip()
 
384
  continue
385
 
386
  # Extract tags
387
+ m = re.match(r"^(\[S\d+\])\s*(.*)", seg)
388
+ tag, content = m.groups() if m else ("", seg)
389
 
390
  # Remove irrelevant symbols
391
  content = re.sub(f"[{re.escape(remove_chars)}]", "", content)
392
 
393
  # Handle consecutive "哈" characters: replace 2 or more with "(笑)"
394
+ content = re.sub(r"哈{2,}", "[]", content)
395
+
396
+ # Handle English laughter (e.g., "haha", "ha ha")
397
+ content = re.sub(r"\b(ha(\s*ha)+)\b", "[laugh]", content, flags=re.IGNORECASE)
398
 
399
  # First handle multi-character punctuation marks
400
+ content = content.replace("——", "")
401
+ content = content.replace("……", "")
402
 
403
  # Handle single-character internal punctuation marks
404
+ internal_punct_map = str.maketrans(
405
+ {";": "", ";": ",", ":": ",", ":": ",", "、": ","}
406
+ )
 
 
 
 
407
  content = content.translate(internal_punct_map)
408
  content = content.strip()
409
 
410
  # Keep only the final period
411
  if len(content) > 1:
412
+ last_ch = (
413
+ ""
414
+ if content[-1] == "，"
415
+ else ("." if content[-1] == "," else content[-1])
416
+ )
417
+ body = content[:-1].replace("。", "，")
418
  content = body + last_ch
419
 
420
+ processed_parts.append({"tag": tag, "content": content})
421
 
422
+ if not processed_parts:
423
+ return ""
424
 
425
+ # Merge consecutive same speakers
426
+ merged_lines = []
427
+ current_tag = processed_parts[0]["tag"]
428
+ current_content = [processed_parts[0]["content"]]
429
 
430
+ for part in processed_parts[1:]:
431
+ if part["tag"] == current_tag and current_tag:
432
+ current_content.append(part["content"])
433
+ else:
434
+ merged_lines.append(f"{current_tag}{''.join(current_content)}".strip())
435
+ current_tag = part["tag"]
436
+ current_content = [part["content"]]
437
+
438
+ merged_lines.append(f"{current_tag}{''.join(current_content)}".strip())
439
+
440
+ return "".join(merged_lines).replace("‘", "'").replace("’", "'")
441
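An illustrative call; the input text is arbitrary and the comment describes the expected effects rather than an exact output string:

raw = "[1]今天天气不错。哈哈哈\n[1]我们出去走走吧。\n[2]好呀!"
print(normalize_text(raw))
# Per the rules above: [1]/[2] become [S1]/[S2], the two adjacent [S1] segments are
# merged, 哈哈哈 becomes (笑), line breaks are dropped, and only each segment's
# final 。 is kept.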
+
442
+
443
+ def process_batch(
444
+ batch_items,
445
+ tokenizer,
446
+ model,
447
+ spt,
448
+ device,
449
+ system_prompt,
450
+ start_idx,
451
+ use_normalize=False,
452
+ silence_duration=0,
453
+ ):
454
  """Process a batch of data items and generate audio, return audio data and metadata"""
455
  try:
456
  # Prepare batch data
 
459
  prompts = [system_prompt] * batch_size
460
  prompt_audios = []
461
  actual_texts_data = [] # Store actual text data used
462
+
463
  print(f"Processing {batch_size} samples starting from index {start_idx}...")
464
+
465
  # Extract text and audio from each sample
466
  for i, item in enumerate(batch_items):
467
  # Use new processing function
468
  processed_item = process_jsonl_item(item)
469
+
470
  text = processed_item["text"]
471
  prompt_text = processed_item["prompt_text"]
472
+
473
+ # Merge prompt text and target text; if prompt_text is empty, full_text is just text
474
+ full_text = prompt_text + text if prompt_text else text
475
  original_full_text = full_text # Save original text
476
+
477
  # Apply text normalization based on parameter
478
  if use_normalize:
479
  full_text = normalize_text(full_text)
480
+
481
  # Replace speaker tags
482
+ final_text = full_text.replace("[S1]", "<speaker1>").replace(
483
+ "[S2]", "<speaker2>"
484
+ )
485
  texts.append(final_text)
486
+
487
  # Save actual text information used
488
+ actual_texts_data.append(
489
+ {
490
+ "index": start_idx + i,
491
+ "original_text": original_full_text,
492
+ "normalized_text": (
493
+ normalize_text(original_full_text) if use_normalize else None
494
+ ),
495
+ "final_text": final_text,
496
+ "use_normalize": use_normalize,
497
+ }
498
+ )
499
+
500
  # Get reference audio
501
  prompt_audios.append(processed_item["prompt_audio"])
502
+
503
  # Process inputs
504
  input_ids_list = []
505
+ for i, (text, prompt, audio_path) in enumerate(
506
+ zip(texts, prompts, prompt_audios)
507
+ ):
508
  # Load audio data here
509
  audio_data = load_audio_data(audio_path) if audio_path else None
510
+ inputs = process_inputs(
511
+ tokenizer, spt, prompt, text, device, silence_duration, audio_data
512
+ )
513
  inputs = shifting_inputs(inputs, tokenizer)
514
  input_ids_list.append(inputs)
515
+
516
  # Pad batch inputs
517
  input_ids, attention_mask = rpadding(input_ids_list, MAX_CHANNELS, tokenizer)
518
+
519
  # Batch generation
520
  print(f"Starting batch audio generation...")
521
  start = input_ids.shape[1] - MAX_CHANNELS + 1
522
+
523
  # Move inputs to GPU
524
  input_ids = input_ids.to(device)
525
  attention_mask = attention_mask.to(device)
526
+
527
  # Generate model outputs
528
  outputs = model.generate(
529
+ input_ids=input_ids,
530
  attention_mask=attention_mask,
531
  )
532
  print(f"Original outputs shape: {outputs.shape}")
 
538
  outputs = outputs[:, start:]
539
  seq_len = outputs.shape[1] - MAX_CHANNELS + 1
540
  speech_ids = torch.full((outputs.shape[0], seq_len, MAX_CHANNELS), 0).to(device)
541
+
 
542
  # Adjust output format
543
  for j in range(MAX_CHANNELS):
544
  speech_ids[..., j] = outputs[:, j : seq_len + j, j]
545
+ if j == 0:
546
  speech_ids[..., j] = speech_ids[..., j] - 151665
547
+
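# Reading column j of the outputs starting j rows later undoes the per-channel delay
# applied in shifting_inputs, and subtracting 151665 from channel 0 reverses the
# codebook-0 offset added in process_inputs, leaving plain codec ids for SPT decoding.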
548
  # Find valid positions for each sample
549
  li = find_max_valid_positions(speech_ids)
550
+
551
  # Store audio result data
552
  audio_results = []
553
+
554
  # Process batch sample results individually
555
  for i in range(batch_size):
556
  try:
 
560
  print(f"Sample {start_idx + i} has no valid speech tokens")
561
  audio_results.append(None)
562
  continue
563
+
564
  this_speech_id = speech_ids[i, :end_idx]
565
+ print(
566
+ f"Speech token shape for sample {start_idx + i}: {this_speech_id.shape}"
567
+ )
568
+
569
+ # Prompt-Augmented Decode (rvq8-style); fall back to original decode if no prompt
570
+ prompt_audio = prompt_audios[i]
571
+ if prompt_audio is None:
572
+ # Fallback to original decode
573
+ with torch.no_grad():
574
+ codes_list = [this_speech_id.permute(1, 0)]
575
+ decode_result = spt.decode(codes_list, overlap_seconds=10)
576
+ audio_out = decode_result["syn_wav_list"][0].cpu().detach()
577
+ if audio_out.ndim == 1:
578
+ audio_out = audio_out.unsqueeze(0)
579
+ audio_results.append(
580
+ {
581
+ "audio_data": audio_out,
582
+ "sample_rate": spt.output_sample_rate,
583
+ "index": start_idx + i,
584
+ }
585
+ )
586
+ print(f"Audio generation completed (orig): sample {start_idx + i}")
587
+ else:
588
+ # 1) Load prompt at SPT input sr and force to 20s
589
+ ref_sr_in = (
590
+ getattr(spt, "input_sample_rate", None)
591
+ or getattr(spt, "sampling_rate", None)
592
+ or 24000
593
+ )
594
+ ref_wav = load_audio_data(
595
+ prompt_audio, target_sample_rate=ref_sr_in
596
+ )
597
+ if ref_wav is None:
598
+ # If ref missing, use original decode
599
+ with torch.no_grad():
600
+ codes_list = [this_speech_id.permute(1, 0)]
601
+ decode_result = spt.decode(codes_list, overlap_seconds=10)
602
+ audio_out = decode_result["syn_wav_list"][0].cpu().detach()
603
+ if audio_out.ndim == 1:
604
+ audio_out = audio_out.unsqueeze(0)
605
+ audio_results.append(
606
+ {
607
+ "audio_data": audio_out,
608
+ "sample_rate": spt.output_sample_rate,
609
+ "index": start_idx + i,
610
+ }
611
+ )
612
+ print(
613
+ f"Audio generation completed (orig no-ref): sample {start_idx + i}"
614
+ )
615
+ else:
616
+ # Encode 20s reference to tokens
617
+ ref_wav_20s = pad_or_truncate_to_seconds(
618
+ ref_wav, 20.0, ref_sr_in
619
+ ).to(device)
620
+ with torch.no_grad():
621
+ enc = spt.encode([ref_wav_20s.squeeze(0)])
622
+ ref_codes = (
623
+ enc["codes_list"][0].to(device).long()
624
+ ) # (nq, T_ref)
625
+
626
+ # Prepare token-to-sample mapping and windowing params
627
+ out_sr = (
628
+ getattr(spt, "output_sample_rate", None)
629
+ or getattr(spt, "sample_rate", None)
630
+ or 24000
631
+ )
632
+ tokens_per_second = float(ref_sr_in) / float(
633
+ spt.encoder_downsample_rate
634
+ )
635
+ tokens_per_chunk = int(round(10.0 * tokens_per_second))
636
+ stride_tokens = 85
637
+ keep_tokens = 85
638
+ left_ctx_tokens = 20
639
+ total_tokens = this_speech_id.shape[0]
640
+ samples_per_token = int(round(out_sr / tokens_per_second))
641
+ crossfade_seconds = 0.1
642
+ crossfade_samples = int(round(crossfade_seconds * out_sr))
643
+
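# Worked numbers for the windowing above, assuming a 16 kHz tokenizer input, a 1280x
# encoder downsample and 32 kHz output audio (i.e. a 12.5 Hz token rate):
#   tokens_per_second = 16000 / 1280          = 12.5
#   tokens_per_chunk  = round(10.0 * 12.5)    = 125 tokens per 10 s window
#   samples_per_token = round(32000 / 12.5)   = 2560 output samples per token
#   crossfade_samples = round(0.1 * 32000)    = 3200
#   keep_tokens + left_ctx_tokens = 85 + 20   = 105, the "remains < 105" threshold below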
644
+ kept_segments = []
645
+ chunk_idx = 0
646
+ while True:
647
+ st_tok = chunk_idx * stride_tokens
648
+ if st_tok >= total_tokens:
649
+ break
650
+ ed_tok = min(st_tok + tokens_per_chunk, total_tokens)
651
+ gen_chunk = this_speech_id[st_tok:ed_tok] # (len, C)
652
+ if gen_chunk.shape[0] == 0:
653
+ break
654
+
655
+ # Concatenate reference tokens with current window tokens
656
+ combined_codes = torch.cat(
657
+ [ref_codes, gen_chunk.permute(1, 0).long()], dim=1
658
+ ).to(
659
+ device
660
+ ) # (nq, T_ref + T_chunk)
661
+ codes_lengths = torch.tensor(
662
+ [combined_codes.shape[-1]],
663
+ dtype=torch.long,
664
+ device=device,
665
+ )
666
+ combined_codes_batched = combined_codes.unsqueeze(
667
+ 1
668
+ ) # (nq, 1, T)
669
+
670
+ with torch.no_grad():
671
+ detok = spt.inference_detokenize(
672
+ combined_codes_batched, codes_lengths
673
+ )
674
+ y = detok["y"][0, 0] # (T_samples)
675
+
676
+ # Remove 20s reference portion (in samples)
677
+ ref_samples = int(round(20.0 * out_sr))
678
+ if y.shape[-1] <= ref_samples:
679
+ chunk_idx += 1
680
+ continue
681
+ chunk_y = y[ref_samples:]
682
+
683
+ # Determine kept region within current window
684
+ window_len = gen_chunk.shape[0]
685
+ remains = total_tokens - st_tok
686
+ is_first = chunk_idx == 0
687
+ is_last = ed_tok >= total_tokens
688
+
689
+ if is_first:
690
+ keep_start_tok = 0
691
+ keep_end_tok = min(
692
+ keep_tokens + left_ctx_tokens, window_len
693
+ )
694
+ elif is_last and remains < 105:
695
+ keep_start_tok = (
696
+ 0 if is_first else min(left_ctx_tokens, window_len)
697
+ )
698
+ keep_end_tok = window_len
699
+ else:
700
+ keep_start_tok = min(left_ctx_tokens, window_len)
701
+ keep_end_tok = min(
702
+ left_ctx_tokens + keep_tokens, window_len
703
+ )
704
+
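# With stride_tokens = keep_tokens = 85 and left_ctx_tokens = 20, window k starts at
# token 85 * k; the first window keeps local tokens [0, 105), later windows drop their
# 20 left-context tokens and keep [20, 105), so the kept spans tile the token stream as
# [0, 105), [105, 190), [190, 275), ... and the final window keeps through the end.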
705
+ keep_start_smps = keep_start_tok * samples_per_token
706
+ keep_end_smps = keep_end_tok * samples_per_token
707
+ left_margin = 0
708
+ right_margin = crossfade_samples if not is_last else 0
709
+ seg_start = max(0, keep_start_smps - left_margin)
710
+ seg_end = min(
711
+ chunk_y.shape[-1], keep_end_smps + right_margin
712
+ )
713
+ if seg_end > seg_start:
714
+ kept_segments.append(
715
+ chunk_y[seg_start:seg_end]
716
+ .detach()
717
+ .cpu()
718
+ .unsqueeze(0)
719
+ )
720
+
721
+ chunk_idx += 1
722
+
723
+ # Concatenate with crossfade; if empty, return tiny silence
724
+ if len(kept_segments) == 0:
725
+ audio_out = torch.zeros(1, int(0.01 * out_sr))
726
+ else:
727
+ audio_out = crossfade_concat(
728
+ kept_segments,
729
+ out_sr,
730
+ crossfade_seconds=crossfade_seconds,
731
+ )
732
+
733
+ audio_results.append(
734
+ {
735
+ "audio_data": audio_out,
736
+ "sample_rate": out_sr,
737
+ "index": start_idx + i,
738
+ }
739
+ )
740
+ print(
741
+ f"Audio generation completed (prompt-aug): sample {start_idx + i}"
742
+ )
743
+
744
  except Exception as e:
745
  print(f"Error processing sample {start_idx + i}: {str(e)}, skipping...")
746
  import traceback
747
+
748
  traceback.print_exc()
749
  audio_results.append(None)
750
+
751
  # Clean up GPU memory
752
  torch.cuda.empty_cache()
753
+
754
  # Return text data and audio data
755
  return actual_texts_data, audio_results
756
+
757
  except Exception as e:
758
  print(f"Error during batch processing: {str(e)}")
759
+ raise
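For orientation, a hypothetical driver loop; items, SYSTEM_PROMPT and the already-loaded tokenizer / model / spt objects are placeholders assumed to be defined elsewhere in the script, not names introduced by this diff:

BATCH = 4
for start in range(0, len(items), BATCH):
    texts_meta, audios = process_batch(
        batch_items=items[start : start + BATCH],
        tokenizer=tokenizer,
        model=model,
        spt=spt,
        device="cuda",
        system_prompt=SYSTEM_PROMPT,
        start_idx=start,
        use_normalize=True,
        silence_duration=0,
    )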