Commit
·
1bf64e7
1
Parent(s):
f2d53f6
Aligned with tts
Browse files- processing_moss_tts.py +20 -25
processing_moss_tts.py
CHANGED
|
@@ -763,7 +763,7 @@ class MossTTSDelayProcessor(ProcessorMixin):
|
|
| 763 |
if breaks.numel() == 0:
|
| 764 |
segments_idx = [idx]
|
| 765 |
else:
|
| 766 |
-
segments_idx = torch.split(idx, breaks)
|
| 767 |
|
| 768 |
audio_codes_list = [audio_codes[s] for s in segments_idx]
|
| 769 |
|
|
@@ -979,30 +979,25 @@ class MossTTSDelayProcessor(ProcessorMixin):
|
|
| 979 |
codes.transpose(0, 1).contiguous().to(device=device, dtype=torch.long)
|
| 980 |
for codes in audio_tokens_list
|
| 981 |
]
|
| 982 |
-
|
| 983 |
-
|
| 984 |
-
|
| 985 |
-
|
| 986 |
-
|
| 987 |
-
|
| 988 |
-
|
| 989 |
-
|
| 990 |
-
max_t =
|
| 991 |
-
|
| 992 |
-
|
| 993 |
-
)
|
| 994 |
-
|
| 995 |
-
|
| 996 |
-
|
| 997 |
-
|
| 998 |
-
|
| 999 |
-
|
| 1000 |
-
|
| 1001 |
-
dec = audio_tokenizer.decode(
|
| 1002 |
-
audio_codes, padding_mask=padding_mask, return_dict=True
|
| 1003 |
-
)
|
| 1004 |
-
audio = dec.audio
|
| 1005 |
-
audio_lengths = dec.audio_lengths
|
| 1006 |
|
| 1007 |
if audio is None or audio_lengths is None:
|
| 1008 |
raise RuntimeError(
|
|
|
|
| 763 |
if breaks.numel() == 0:
|
| 764 |
segments_idx = [idx]
|
| 765 |
else:
|
| 766 |
+
segments_idx = torch.split(idx, breaks.tolist())
|
| 767 |
|
| 768 |
audio_codes_list = [audio_codes[s] for s in segments_idx]
|
| 769 |
|
|
|
|
| 979 |
codes.transpose(0, 1).contiguous().to(device=device, dtype=torch.long)
|
| 980 |
for codes in audio_tokens_list
|
| 981 |
]
|
| 982 |
+
|
| 983 |
+
# Fallback: pad to (NQ, B, T) + mask, then decode.
|
| 984 |
+
nq = int(codes_list[0].shape[0])
|
| 985 |
+
max_t = max(int(c.shape[1]) for c in codes_list)
|
| 986 |
+
audio_codes = torch.zeros(
|
| 987 |
+
nq, len(codes_list), max_t, device=device, dtype=torch.long
|
| 988 |
+
)
|
| 989 |
+
padding_mask = torch.zeros(
|
| 990 |
+
len(codes_list), max_t, device=device, dtype=torch.bool
|
| 991 |
+
)
|
| 992 |
+
for i, c in enumerate(codes_list):
|
| 993 |
+
t = int(c.shape[1])
|
| 994 |
+
audio_codes[:, i, :t] = c
|
| 995 |
+
padding_mask[i, :t] = True
|
| 996 |
+
dec = audio_tokenizer.decode(
|
| 997 |
+
audio_codes, padding_mask=padding_mask, return_dict=True, chunk_duration=8
|
| 998 |
+
)
|
| 999 |
+
audio = dec.audio
|
| 1000 |
+
audio_lengths = dec.audio_lengths
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1001 |
|
| 1002 |
if audio is None or audio_lengths is None:
|
| 1003 |
raise RuntimeError(
|