CaasiHUANG commited on
Commit
1bf64e7
·
1 Parent(s): f2d53f6

Aligned with tts

Browse files
Files changed (1) hide show
  1. processing_moss_tts.py +20 -25
processing_moss_tts.py CHANGED
@@ -763,7 +763,7 @@ class MossTTSDelayProcessor(ProcessorMixin):
763
  if breaks.numel() == 0:
764
  segments_idx = [idx]
765
  else:
766
- segments_idx = torch.split(idx, breaks)
767
 
768
  audio_codes_list = [audio_codes[s] for s in segments_idx]
769
 
@@ -979,30 +979,25 @@ class MossTTSDelayProcessor(ProcessorMixin):
979
  codes.transpose(0, 1).contiguous().to(device=device, dtype=torch.long)
980
  for codes in audio_tokens_list
981
  ]
982
-
983
- if hasattr(audio_tokenizer, "batch_decode"):
984
- dec = audio_tokenizer.batch_decode(codes_list)
985
- audio = dec.audio # (B, C, T)
986
- audio_lengths = dec.audio_lengths # (B,)
987
- else:
988
- # Fallback: pad to (NQ, B, T) + mask, then decode.
989
- nq = int(codes_list[0].shape[0])
990
- max_t = max(int(c.shape[1]) for c in codes_list)
991
- audio_codes = torch.zeros(
992
- nq, len(codes_list), max_t, device=device, dtype=torch.long
993
- )
994
- padding_mask = torch.zeros(
995
- len(codes_list), max_t, device=device, dtype=torch.bool
996
- )
997
- for i, c in enumerate(codes_list):
998
- t = int(c.shape[1])
999
- audio_codes[:, i, :t] = c
1000
- padding_mask[i, :t] = True
1001
- dec = audio_tokenizer.decode(
1002
- audio_codes, padding_mask=padding_mask, return_dict=True
1003
- )
1004
- audio = dec.audio
1005
- audio_lengths = dec.audio_lengths
1006
 
1007
  if audio is None or audio_lengths is None:
1008
  raise RuntimeError(
 
763
  if breaks.numel() == 0:
764
  segments_idx = [idx]
765
  else:
766
+ segments_idx = torch.split(idx, breaks.tolist())
767
 
768
  audio_codes_list = [audio_codes[s] for s in segments_idx]
769
 
 
979
  codes.transpose(0, 1).contiguous().to(device=device, dtype=torch.long)
980
  for codes in audio_tokens_list
981
  ]
982
+
983
+ # Fallback: pad to (NQ, B, T) + mask, then decode.
984
+ nq = int(codes_list[0].shape[0])
985
+ max_t = max(int(c.shape[1]) for c in codes_list)
986
+ audio_codes = torch.zeros(
987
+ nq, len(codes_list), max_t, device=device, dtype=torch.long
988
+ )
989
+ padding_mask = torch.zeros(
990
+ len(codes_list), max_t, device=device, dtype=torch.bool
991
+ )
992
+ for i, c in enumerate(codes_list):
993
+ t = int(c.shape[1])
994
+ audio_codes[:, i, :t] = c
995
+ padding_mask[i, :t] = True
996
+ dec = audio_tokenizer.decode(
997
+ audio_codes, padding_mask=padding_mask, return_dict=True, chunk_duration=8
998
+ )
999
+ audio = dec.audio
1000
+ audio_lengths = dec.audio_lengths
 
 
 
 
 
1001
 
1002
  if audio is None or audio_lengths is None:
1003
  raise RuntimeError(