Commit b6ec358 (parent 24d0b1d) · committed by Ubuntu

update llm training

speech/config.yaml CHANGED
@@ -187,23 +187,15 @@ feat_extractor: !name:matcha.utils.audio.mel_spectrogram
     fmin: 0
     fmax: 8000
     center: False
-compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
-    feat_extractor: !ref <feat_extractor>
-    token_mel_ratio: !ref <token_mel_ratio>
-compute_f0: !name:cosyvoice.dataset.processor.compute_f0
-    sample_rate: !ref <sample_rate>
-    hop_size: 480
-parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
-    normalize: True
+
 shuffle: !name:cosyvoice.dataset.processor.shuffle
     shuffle_size: 1000
 sort: !name:cosyvoice.dataset.processor.sort
     sort_size: 500 # sort_size should be less than shuffle_size
 batch: !name:cosyvoice.dataset.processor.batch
     batch_type: 'dynamic'
-    max_frames_in_batch: 5000
+    max_frames_in_batch: 25000
 padding: !name:cosyvoice.dataset.processor.padding
-    use_spk_embedding: False # change to True during sft
     use_speaker_encoder: !ref <use_speaker_encoder>


@@ -213,9 +205,7 @@ data_pipeline: [
     !ref <tokenize>,
     !ref <filter>,
     !ref <resample>,
-    !ref <compute_fbank>,
     !ref <extract_reference_mel>, # Add this for speaker encoder
-    !ref <parse_embedding>,
     !ref <shuffle>,
     !ref <sort>,
     !ref <batch>,
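With compute_fbank, compute_f0, and parse_embedding dropped, the data_pipeline in this config reduces to tokenize → filter → resample → extract_reference_mel → shuffle → sort → batch → padding; padding also loses the use_spk_embedding flag (matching the new padding() signature in processor.py), and the dynamic batch budget grows from 5000 to 25000 frames. Each entry in data_pipeline is a processor that consumes an iterable of sample dicts and yields sample dicts, so the stages chain as generators. A minimal, self-contained sketch of that pattern, using toy stand-ins rather than the real cosyvoice.dataset.processor functions:

import random

def toy_resample(data, resample_rate=24000):
    # each processor takes an iterable of sample dicts and yields them back, possibly modified
    for sample in data:
        sample['sample_rate'] = resample_rate
        yield sample

def toy_shuffle(data, shuffle_size=3):
    # buffer a few samples, shuffle, then flush - the same shape as processor.shuffle
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= shuffle_size:
            random.shuffle(buf)
            yield from buf
            buf = []
    random.shuffle(buf)
    yield from buf

stream = ({'utt': f'utt_{i}', 'sample_rate': 16000} for i in range(5))
for proc in (toy_resample, toy_shuffle):  # applied in data_pipeline order
    stream = proc(stream)
print(list(stream))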
speech/cosyvoice/dataset/processor.py CHANGED
@@ -24,17 +24,15 @@ import pyworld as pw
 import glob
 import os
 import json
-
+import traceback
 AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
 
 
-def individual_file_opener(data, mode='train', tts_data={}):
-    """Load data from individual files instead of parquet
+def individual_file_opener(data, mode='train', tts_data={}, token_latent_ratio=3):
+    """Load data from individual files listed in files.txt
 
     Args:
-        data: Iterable[{src}] where src is either:
-            - Path to a directory containing audio files
-            - Path to a JSON index file
+        data: Iterable[{src}] where src is path to files.txt containing audio paths
         mode: 'train' or 'test'
         tts_data: Dict for TTS mode
 
@@ -45,51 +43,93 @@ def individual_file_opener(data, mode='train', tts_data={}):
         assert 'src' in sample
         src = sample['src']
 
-        # Determine if src is a directory or index file
-        if src.endswith('.json'):
-            # Load from index file
+        # Load file list from files.txt
+        file_list = []
+
+        # Check if src is a files.txt file
+        if src.endswith('.txt'):
+            with open(src, 'r') as f:
+                wav_files = [line.strip() for line in f if line.strip()]
+
+            for wav_path in wav_files:
+                # Skip empty lines or comments
+                if not wav_path or wav_path.startswith('#'):
+                    continue
+
+                # Verify wav file exists
+                if not os.path.exists(wav_path):
+                    logging.warning(f'Audio file not found: {wav_path}, skipping')
+                    continue
+
+                # Check if all required files exist
+                txt_path = wav_path.replace('.wav', '.txt')
+                token_path = wav_path.replace('.wav', '_fsq.pt')
+                latent_path = wav_path.replace('.wav', '_latent.pt')
+
+                if not os.path.exists(txt_path):
+                    logging.warning(f'Text file not found for {wav_path}, skipping')
+                    continue
+
+                if not os.path.exists(token_path):
+                    logging.warning(f'Token file not found for {wav_path}, skipping')
+                    continue
+
+                if not os.path.exists(latent_path):
+                    logging.warning(f'Latent file not found for {wav_path}, skipping')
+                    continue
+
+                # Extract metadata
+                utt = os.path.basename(wav_path).replace('.wav', '')
+                # Try to extract speaker from filename (assuming format: spk_*.wav)
+                spk = utt.split('_')[0] if '_' in utt else 'default'
+
+                file_info = {
+                    'utt': utt,
+                    'spk': spk,
+                    'wav': wav_path,
+                    'text_path': txt_path,
+                    'token_path': token_path,
+                    'latent_path': latent_path,
+                }
+                logging.info(f'file_info {file_info}')
+                file_list.append(file_info)
+
+        elif src.endswith('.json'):
+            # Keep backward compatibility with JSON index files
             with open(src, 'r') as f:
                 index_data = json.load(f)
                 file_list = index_data.get('data', [])
+
         else:
-            # Scan directory for wav files
+            # Assume it's a directory for backward compatibility
             wav_files = glob.glob(os.path.join(src, '*/*/*wav'))
             if not wav_files:
-                # Try different patterns
                 wav_files = glob.glob(os.path.join(src, '**/*.wav'), recursive=True)
 
-            file_list = []
             for wav_path in wav_files:
-                # Check if all required files exist
-                txt_path = wav_path.replace('.wav', '.normalized.txt')
-                embedding_path = wav_path.replace('.wav', '_embedding.pt')
-                token_path = wav_path.replace('.wav', '_tokens.pt')
+                txt_path = wav_path.replace('.wav', '.txt')
+                token_path = wav_path.replace('.wav', '_fsq.pt')
+                latent_path = wav_path.replace('.wav', '_latent.pt')
 
                 if not os.path.exists(txt_path):
                     logging.warning(f'Text file not found for {wav_path}, skipping')
                     continue
 
-                # Extract metadata
                 utt = os.path.basename(wav_path).replace('.wav', '')
                 spk = utt.split('_')[0]
 
-                # Find speaker embedding
-                spk_embed_dir = os.path.join(os.path.dirname(src), 'spk_embeddings')
-                if not os.path.exists(spk_embed_dir):
-                    spk_embed_dir = os.path.join(src, 'spk_embeddings')
-                spk_embedding_path = os.path.join(spk_embed_dir, f'{spk}_embedding.pt')
-
                 file_info = {
                     'utt': utt,
                     'spk': spk,
                     'wav': wav_path,
                     'text_path': txt_path,
-                    'embedding_path': embedding_path,
                     'token_path': token_path,
-                    'spk_embedding_path': spk_embedding_path
+                    'latent_path': latent_path,
                 }
                 file_list.append(file_info)
 
+        logging.info(f'Found {len(file_list)} valid audio files from {src}')
+
         # Process each file
         for file_info in file_list:
             try:
@@ -98,38 +138,24 @@ def individual_file_opener(data, mode='train', tts_data={}):
                     audio_data = f.read()
 
                 # Read text
-                with open(file_info['text_path'], 'r') as f:
+                with open(file_info['text_path'], 'r', encoding='utf-8') as f:
                     text = ''.join(l.strip() for l in f.readlines())
 
-                # Load embeddings if they exist
-                if os.path.exists(file_info['embedding_path']):
-                    utt_embedding = torch.load(file_info['embedding_path'], weights_only=False)
-                    if isinstance(utt_embedding, torch.Tensor):
-                        utt_embedding = utt_embedding.tolist()
-                else:
-                    logging.warning(f"Utterance embedding not found: {file_info['embedding_path']}")
-                    # Create a dummy embedding
-                    utt_embedding = [0.0] * 192  # Assuming 192-dim embeddings
-
-                # Load tokens if they exist
-                if os.path.exists(file_info['token_path']):
-                    speech_token = torch.load(file_info['token_path'], weights_only=False)
-                    if isinstance(speech_token, torch.Tensor):
-                        speech_token = speech_token.tolist()
-                else:
-                    logging.warning(f"Token file not found: {file_info['token_path']}")
-                    speech_token = []
-
-                # Load speaker embedding
-                if os.path.exists(file_info['spk_embedding_path']):
-                    spk_embedding = torch.load(file_info['spk_embedding_path'], weights_only=False)
-                    if isinstance(spk_embedding, torch.Tensor):
-                        spk_embedding = spk_embedding.tolist()
-                else:
-                    logging.warning(f"Speaker embedding not found: {file_info['spk_embedding_path']}")
-                    # Use utterance embedding as fallback
-                    spk_embedding = utt_embedding
-
+                # Load speech token
+                speech_token = torch.load(file_info['token_path'], map_location='cpu', weights_only=False)
+                if isinstance(speech_token, torch.Tensor):
+                    speech_token = speech_token.tolist()
+
+                # Load speech latent
+                speech_latent = torch.load(file_info['latent_path'], map_location='cpu', weights_only=False)
+                speech_latent = speech_latent['z'].transpose(0, 1)
+
+                if token_latent_ratio != 0:
+                    # trim to align speech_token and speech_feat
+                    token_len = int(min(speech_latent.shape[0] / token_latent_ratio, len(speech_token)))
+                    speech_latent = speech_latent[:token_latent_ratio * token_len]
+                    speech_token = speech_token[:token_len]
+
                 # Build sample dict
                 sample_dict = {
                     'utt': file_info['utt'],
@@ -137,10 +163,9 @@ def individual_file_opener(data, mode='train', tts_data={}):
                     'audio_data': audio_data,
                     'text': text,
                     'text_token': [],  # Will be filled by tokenize processor
-                    'utt_embedding': utt_embedding,
-                    'spk_embedding': spk_embedding,
                     'speech_token': speech_token,
                     'wav': file_info['wav'],  # Keep original path for reference
+                    'speech_latent': speech_latent,
                 }
 
                 # Copy over any additional fields from the original sample
@@ -237,8 +262,10 @@ def filter(data,
             continue
         if num_frames != 0:
             if len(sample['text_token']) / num_frames < min_output_input_ratio:
+                print('continue text_token')
                 continue
             if len(sample['text_token']) / num_frames > max_output_input_ratio:
+                print('continue text_token')
                 continue
         yield sample
 
@@ -261,6 +288,7 @@ def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
         waveform = sample['speech']
         if sample_rate != resample_rate:
             if sample_rate < min_sample_rate:
+                print('continue sample_rate')
                 continue
             sample['sample_rate'] = resample_rate
             sample['speech'] = torchaudio.transforms.Resample(
@@ -292,43 +320,6 @@ def truncate(data, truncate_length=24576, mode='train'):
         yield sample
 
 
-def compute_fbank(data,
-                  feat_extractor,
-                  token_mel_ratio=0,
-                  mode='train'):
-    """ Extract fbank
-
-    Args:
-        data: Iterable[{key, wav, label, sample_rate}]
-
-    Returns:
-        Iterable[{key, feat, label}]
-    """
-    for sample in data:
-        assert 'sample_rate' in sample
-        assert 'speech' in sample
-        assert 'utt' in sample
-        assert 'text_token' in sample
-        waveform = sample['speech']
-        feat = feat_extractor(waveform).squeeze(dim=0).transpose(0, 1)
-        if token_mel_ratio != 0:
-
-            if isinstance(sample["speech_token"], list):
-                speech_token_tensor = torch.tensor(sample["speech_token"])
-            else:
-                speech_token_tensor = sample["speech_token"]
-
-            # trim to align speech_token and speech_feat
-            token_len = int(min(feat.shape[0] / token_mel_ratio, speech_token_tensor.shape[0]))
-            feat = feat[:token_mel_ratio * token_len]
-
-            # Update speech_token - keep as tensor for consistency
-            sample["speech_token"] = speech_token_tensor[:token_len]
-
-        sample['speech_feat'] = feat
-        yield sample
-
-
 def extract_reference_mel_from_speech(
         data,
         feat_extractor,
@@ -361,6 +352,7 @@ def extract_reference_mel_from_speech(
             sample['reference_mels'] = []
            sample['reference_mel_lengths'] = []
             sample['num_references'] = 0
+            print('continue num_references')
             yield sample
             continue
 
@@ -403,48 +395,6 @@ def extract_reference_mel_from_speech(
 
         yield sample
 
-def compute_f0(data, sample_rate, hop_size, mode='train'):
-    """ Extract f0
-
-    Args:
-        data: Iterable[{key, wav, label, sample_rate}]
-
-    Returns:
-        Iterable[{key, feat, label}]
-    """
-    frame_period = hop_size * 1000 / sample_rate
-    for sample in data:
-        assert 'sample_rate' in sample
-        assert 'speech' in sample
-        assert 'utt' in sample
-        assert 'text_token' in sample
-        waveform = sample['speech']
-        _f0, t = pw.harvest(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period)
-        if sum(_f0 != 0) < 5: # this happens when the algorithm fails
-            _f0, t = pw.dio(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period) # if harvest fails, try dio
-        f0 = pw.stonemask(waveform.squeeze(dim=0).numpy().astype('double'), _f0, t, sample_rate)
-        f0 = F.interpolate(torch.from_numpy(f0).view(1, 1, -1), size=sample['speech_feat'].shape[0], mode='linear').view(-1)
-        sample['pitch_feat'] = f0
-        yield sample
-
-
-def parse_embedding(data, normalize, mode='train'):
-    """ Parse utt_embedding/spk_embedding
-
-    Args:
-        data: Iterable[{key, wav, label, sample_rate}]
-
-    Returns:
-        Iterable[{key, feat, label}]
-    """
-    for sample in data:
-        sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
-        sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
-        if normalize:
-            sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
-            sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
-        yield sample
-
 
 def tokenize(data, get_tokenizer, allowed_special, mode='train'):
     """ Decode text to chars or BPE
@@ -505,12 +455,12 @@ def sort(data, sort_size=500, mode='train'):
     for sample in data:
         buf.append(sample)
         if len(buf) >= sort_size:
-            buf.sort(key=lambda x: x['speech_feat'].size(0))
+            buf.sort(key=lambda x: x['speech_latent'].size(0))
             for x in buf:
                 yield x
             buf = []
     # The sample left over
-    buf.sort(key=lambda x: x['speech_feat'].size(0))
+    buf.sort(key=lambda x: x['speech_latent'].size(0))
     for x in buf:
         yield x
 
@@ -549,9 +499,9 @@ def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
     buf = []
     longest_frames = 0
     for sample in data:
-        assert 'speech_feat' in sample
-        assert isinstance(sample['speech_feat'], torch.Tensor)
-        new_sample_frames = sample['speech_feat'].size(0)
+        assert 'speech_latent' in sample
+        assert isinstance(sample['speech_latent'], torch.Tensor)
+        new_sample_frames = sample['speech_latent'].size(0)
         longest_frames = max(longest_frames, new_sample_frames)
         frames_after_padding = longest_frames * (len(buf) + 1)
         if frames_after_padding > max_frames_in_batch:
@@ -574,7 +524,7 @@ def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, m
     else:
         logging.fatal('Unsupported batch type {}'.format(batch_type))
 
-def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False, use_speaker_encoder=False):
+def padding(data, mode='train', gan=False, dpo=False, use_speaker_encoder=False):
     """ Padding the data into training data
 
     Args:
@@ -586,9 +536,9 @@ def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False, use_spe
     """
     for sample in data:
         assert isinstance(sample, list)
-        speech_feat_len = torch.tensor([x['speech_feat'].size(0) for x in sample],  # Changed from size(1) to size(0)
+        speech_latent_len = torch.tensor([x['speech_latent'].size(0) for x in sample],  # Changed from size(1) to size(0)
                                        dtype=torch.int32)
-        order = torch.argsort(speech_feat_len, descending=True)
+        order = torch.argsort(speech_latent_len, descending=True)
 
         utts = [sample[i]['utt'] for i in order]
         speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
@@ -607,17 +557,16 @@ def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False, use_spe
                                    batch_first=True,
                                    padding_value=0)
 
-        speech_feat = [sample[i]['speech_feat'] for i in order]
-        speech_feat_len = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
-        speech_feat = pad_sequence(speech_feat,
+        speech_latent = [sample[i]['speech_latent'] for i in order]
+
+        speech_latent = pad_sequence(speech_latent,
                                    batch_first=True,
                                    padding_value=0)
+
         text = [sample[i]['text'] for i in order]
         text_token = [torch.tensor(sample[i]['text_token']) for i in order]
         text_token_len = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
         text_token = pad_sequence(text_token, batch_first=True, padding_value=0)
-        utt_embedding = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
-        spk_embedding = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
 
         batch = {
             "utts": utts,
@@ -625,13 +574,11 @@ def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False, use_spe
             "speech_len": speech_len,
             "speech_token": speech_token,
             "speech_token_len": speech_token_len,
-            "speech_feat": speech_feat,
-            "speech_feat_len": speech_feat_len,
+            "speech_latent": speech_latent,
+            "speech_latent_len": speech_latent,
             "text": text,
             "text_token": text_token,
             "text_token_len": text_token_len,
-            "utt_embedding": utt_embedding,
-            "spk_embedding": spk_embedding,
         }
 
         # Handle reference mels for speaker encoder
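The rewritten individual_file_opener expects src to be a files.txt listing one wav path per line (lines starting with '#' are skipped), and for every wav it looks for a transcript, an FSQ token file, and a latent file next to it; samples now carry speech_latent instead of speech_feat, and sort, dynamic_batch, and padding key on speech_latent accordingly. A hedged sketch of the assumed on-disk layout and of the token/latent alignment trim (token_latent_ratio=3); the paths and shapes below are made-up examples, not part of the repo:

import torch

# Assumed sibling files for one utterance (derived by replacing the '.wav' suffix):
#   /data/corpus/spk1_0001.wav          raw audio
#   /data/corpus/spk1_0001.txt          transcript
#   /data/corpus/spk1_0001_fsq.pt       speech tokens (list or 1-D tensor)
#   /data/corpus/spk1_0001_latent.pt    dict with key 'z' shaped [latent_dim, T_latent]
# files.txt simply lists the wav paths, one per line.

token_latent_ratio = 3                         # latent frames per speech token
speech_token = list(range(100))                # stand-in for the contents of *_fsq.pt
speech_latent = torch.randn(80, 290)           # stand-in for latent['z']: [latent_dim, T_latent]
speech_latent = speech_latent.transpose(0, 1)  # -> [T_latent, latent_dim], as in the opener

# trim both streams so that T_latent == token_latent_ratio * num_tokens
token_len = int(min(speech_latent.shape[0] / token_latent_ratio, len(speech_token)))
speech_latent = speech_latent[:token_latent_ratio * token_len]
speech_token = speech_token[:token_len]
print(len(speech_token), tuple(speech_latent.shape))  # 96 (288, 80)

One detail in the padding hunk worth flagging: the batch dict assigns the padded speech_latent tensor to "speech_latent_len" as well; code that expects per-sample lengths would normally take something like torch.tensor([x.size(0) for x in speech_latent], dtype=torch.int32) computed before pad_sequence.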
speech/cosyvoice/utils/executor.py CHANGED
@@ -78,6 +78,10 @@ class Executor:
                 info_dict["epoch"] = self.epoch
                 info_dict["batch_idx"] = batch_idx
 
+                for key, value in batch_dict.items():
+                    if isinstance(value, torch.Tensor):
+                        print(f'{key} {value.shape}')
+
 
                 if use_ddp and (batch_idx + 1) % info_dict["accum_grad"] != 0:
                     context = model.no_sync
@@ -86,6 +90,7 @@ class Executor:
 
 
                 with context():
+                    logger.info(f'{self.rank} batch_forward')
                     info_dict = batch_forward(
                         model,
                         batch_dict,
@@ -94,12 +99,13 @@ class Executor:
                         ref_model=self.ref_model,
                         dpo_loss=self.dpo_loss,
                     )
-
+                    logger.info(f'{self.rank} batch_backward')
                     info_dict = batch_backward(model, scaler, info_dict)
-
+                    logger.info(f'{self.rank} update_parameter_and_lr')
                 info_dict = update_parameter_and_lr(
                     model, optimizer, scheduler, scaler, info_dict, model_type=model_type
                 )
+                logger.info(f'{self.rank} log_per_step')
                 log_per_step(experiment, info_dict)
 
                 if (
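The executor changes are debugging instrumentation: each rank prints every tensor shape in batch_dict and logs its progress through batch_forward, batch_backward, update_parameter_and_lr, and log_per_step, which is a common way to localize DDP hangs or shape mismatches. On multi-GPU runs such output interleaves across ranks, so it is often gated on rank; a small hedged sketch of that pattern (the helper below is illustrative, not part of the repo):

import os
import logging

import torch

logger = logging.getLogger(__name__)

def debug_batch_shapes(batch_dict, rank=None, only_rank0=True):
    # log tensor shapes once per step; restricting to rank 0 keeps multi-GPU logs readable
    rank = int(os.environ.get('RANK', 0)) if rank is None else rank
    if only_rank0 and rank != 0:
        return
    for key, value in batch_dict.items():
        if isinstance(value, torch.Tensor):
            logger.info('rank %d | %s %s', rank, key, tuple(value.shape))

# usage inside the training loop, mirroring the added print loop:
#     debug_batch_shapes(batch_dict, rank=self.rank)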
speech/cosyvoice/utils/train_utils.py CHANGED
@@ -312,7 +312,7 @@ def batch_forward(model, batch, scaler, info_dict, ref_model=None, dpo_loss=None
 
     with autocast:
         info_dict['loss_dict'] = model(batch, device)
-        # print('infor_dict loss_dict : ', info_dict['loss_dict'])
+        print('infor_dict loss_dict : ', info_dict['loss_dict'])
     if ref_model is not None and dpo_loss is not None:
         chosen_logps = info_dict['loss_dict']["chosen_logps"]
         rejected_logps = info_dict['loss_dict']["rejected_logps"]
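Re-enabling print('infor_dict loss_dict : ', info_dict['loss_dict']) dumps the raw loss dict, tensors included, on every step. If the goal is just to watch the scalar losses, a hedged alternative is to reduce 0-dim tensors to floats first; this is a sketch, not what the repo does:

import torch

def format_loss_dict(loss_dict):
    # turn scalar tensors into plain floats so the per-step log stays one short line
    return {k: (v.item() if isinstance(v, torch.Tensor) and v.numel() == 1 else v)
            for k, v in loss_dict.items()}

print('loss_dict:', format_loss_dict({'loss': torch.tensor(1.234), 'acc': torch.tensor(0.9)}))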
speech/tools/generate_json_index.py ADDED
#!/usr/bin/env python3
"""
Generate JSON index file for dataset
This creates a JSON file with all valid audio files and their metadata
"""

import os
import json
import glob
import argparse
from pathlib import Path
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
import hashlib
from datetime import datetime

def validate_file_set(wav_path):
    """Check if all required files exist for a given wav file"""
    txt_path = wav_path.replace('.wav', '.txt')
    token_path = wav_path.replace('.wav', '_fsq.pt')
    latent_path = wav_path.replace('.wav', '_latent.pt')

    # Check all files exist
    if not all(os.path.exists(p) for p in [wav_path, txt_path, token_path, latent_path]):
        return None

    # Get file sizes for validation
    try:
        wav_size = os.path.getsize(wav_path)
        txt_size = os.path.getsize(txt_path)
        token_size = os.path.getsize(token_path)
        latent_size = os.path.getsize(latent_path)

        # Skip if any file is empty
        if any(size == 0 for size in [wav_size, txt_size, token_size, latent_size]):
            return None

        # Extract metadata
        utt = os.path.basename(wav_path).replace('.wav', '')
        spk = utt.split('_')[0] if '_' in utt else 'default'

        # Get file modification time
        mtime = os.path.getmtime(wav_path)

        return {
            'utt': utt,
            'spk': spk,
            'wav': wav_path,
            'text_path': txt_path,
            'token_path': token_path,
            'latent_path': latent_path,
            'wav_size': wav_size,
            'txt_size': txt_size,
            'token_size': token_size,
            'latent_size': latent_size,
            'mtime': mtime,
        }
    except Exception as e:
        print(f"Error processing {wav_path}: {e}")
        return None

def process_directory(directory, pattern='**/*.wav'):
    """Process a directory and find all valid audio files"""
    print(f"Scanning directory: {directory}")
    wav_files = glob.glob(os.path.join(directory, pattern), recursive=True)
    print(f"Found {len(wav_files)} wav files")

    valid_files = []

    # Process files in parallel
    with ThreadPoolExecutor(max_workers=16) as executor:
        futures = [executor.submit(validate_file_set, wav_path) for wav_path in wav_files]

        for future in tqdm(as_completed(futures), total=len(futures), desc="Validating files"):
            result = future.result()
            if result is not None:
                valid_files.append(result)

    return valid_files

def process_files_txt(files_txt):
    """Process files from a text file list"""
    print(f"Reading file list from: {files_txt}")

    with open(files_txt, 'r') as f:
        wav_files = [line.strip() for line in f if line.strip() and not line.startswith('#')]

    print(f"Found {len(wav_files)} files in list")

    valid_files = []

    # Process files in parallel
    with ThreadPoolExecutor(max_workers=16) as executor:
        futures = [executor.submit(validate_file_set, wav_path) for wav_path in wav_files]

        for future in tqdm(as_completed(futures), total=len(futures), desc="Validating files"):
            result = future.result()
            if result is not None:
                valid_files.append(result)

    return valid_files

def generate_statistics(file_list):
    """Generate statistics about the dataset"""
    stats = {
        'total_files': len(file_list),
        'total_size_gb': sum(f['wav_size'] + f['txt_size'] + f['token_size'] + f['latent_size']
                             for f in file_list) / (1024**3),
        'speakers': {},
        'file_sizes': {
            'wav_total_gb': sum(f['wav_size'] for f in file_list) / (1024**3),
            'txt_total_mb': sum(f['txt_size'] for f in file_list) / (1024**2),
            'token_total_gb': sum(f['token_size'] for f in file_list) / (1024**3),
            'latent_total_gb': sum(f['latent_size'] for f in file_list) / (1024**3),
        }
    }

    # Count files per speaker
    for f in file_list:
        spk = f['spk']
        if spk not in stats['speakers']:
            stats['speakers'][spk] = 0
        stats['speakers'][spk] += 1

    stats['num_speakers'] = len(stats['speakers'])

    return stats

def generate_json_index(input_paths, output_file, split_ratio=None):
    """
    Generate JSON index file from input paths

    Args:
        input_paths: List of paths (directories or files.txt)
        output_file: Output JSON file path
        split_ratio: Optional tuple (train_ratio, val_ratio, test_ratio)
    """
    all_files = []

    # Process each input path
    for path in input_paths:
        if os.path.isdir(path):
            files = process_directory(path)
        elif path.endswith('.txt'):
            files = process_files_txt(path)
        else:
            print(f"Warning: Skipping unknown input type: {path}")
            continue

        all_files.extend(files)

    # Remove duplicates based on utterance ID
    unique_files = {}
    for f in all_files:
        utt = f['utt']
        if utt not in unique_files:
            unique_files[utt] = f
        else:
            # Keep the one with newer modification time
            if f['mtime'] > unique_files[utt]['mtime']:
                unique_files[utt] = f

    file_list = list(unique_files.values())

    # Sort by utterance ID for consistency
    file_list.sort(key=lambda x: x['utt'])

    print(f"\nTotal unique files: {len(file_list)}")

    # Generate statistics
    stats = generate_statistics(file_list)

    # Create index
    index = {
        'version': '1.0',
        'created': datetime.now().isoformat(),
        'statistics': stats,
        'data': file_list
    }

    # Optional: Create train/val/test splits
    if split_ratio:
        import random
        random.seed(42)  # For reproducibility

        train_ratio, val_ratio, test_ratio = split_ratio
        assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 0.001, "Split ratios must sum to 1"

        # Shuffle for random split
        shuffled = file_list.copy()
        random.shuffle(shuffled)

        n = len(shuffled)
        train_end = int(n * train_ratio)
        val_end = int(n * (train_ratio + val_ratio))

        splits = {
            'train': shuffled[:train_end],
            'val': shuffled[train_end:val_end],
            'test': shuffled[val_end:]
        }

        # Save separate files for each split
        base_name = output_file.replace('.json', '')

        for split_name, split_data in splits.items():
            split_index = {
                'version': '1.0',
                'created': datetime.now().isoformat(),
                'split': split_name,
                'statistics': generate_statistics(split_data),
                'data': split_data
            }

            split_file = f"{base_name}_{split_name}.json"
            with open(split_file, 'w') as f:
                json.dump(split_index, f, indent=2)

            print(f"Saved {split_name} split: {split_file} ({len(split_data)} files)")

    # Save main index
    with open(output_file, 'w') as f:
        json.dump(index, f, indent=2)

    print(f"\nSaved index to: {output_file}")
    print(f"Total files: {stats['total_files']}")
    print(f"Total size: {stats['total_size_gb']:.2f} GB")
    print(f"Number of speakers: {stats['num_speakers']}")

def main():
    parser = argparse.ArgumentParser(description="Generate JSON index for dataset")
    parser.add_argument('--input', nargs='+', required=True,
                        help='Input paths (directories or files.txt)')
    parser.add_argument('--output', default='dataset_index.json',
                        help='Output JSON file (default: dataset_index.json)')
    parser.add_argument('--split', nargs=3, type=float, metavar=('TRAIN', 'VAL', 'TEST'),
                        help='Create train/val/test splits (e.g., --split 0.8 0.1 0.1)')

    args = parser.parse_args()

    # Validate split ratios if provided
    split_ratio = None
    if args.split:
        split_ratio = tuple(args.split)
        if abs(sum(split_ratio) - 1.0) > 0.001:
            parser.error("Split ratios must sum to 1.0")

    generate_json_index(args.input, args.output, split_ratio)

if __name__ == "__main__":
    main()

# Example usage:
# python generate_json_index.py --input /data/dataset/emilia /data/dataset/vivoice --output dataset_index.json
# python generate_json_index.py --input train_files.txt --output train_index.json
# python generate_json_index.py --input /data/dataset/emilia --output dataset_index.json --split 0.8 0.1 0.1
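For reference, a hedged sketch of consuming the index this script writes; the field names follow the dicts built above, and 'dataset_index.json' is just the default --output value:

import json

with open('dataset_index.json') as f:
    index = json.load(f)

stats = index['statistics']
print(index['version'], index['created'])
print(stats['total_files'], 'files,', stats['num_speakers'], 'speakers,',
      round(stats['total_size_gb'], 2), 'GB')

first = index['data'][0]
print(first['utt'], first['spk'])
print(first['wav'], first['token_path'], first['latent_path'])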