# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Streaming inference utilities for MossTTSRealtime."""
from __future__ import annotations
import contextlib
import re
from typing import Iterable, Iterator, List, Optional, Sequence
import numpy as np
import torch
import torch.nn.functional as F
from transformers.cache_utils import StaticCache
from transformers.utils import is_torchaudio_available, requires_backends
from transformers.utils.import_utils import requires
if is_torchaudio_available():
import torchaudio
@requires(backends=("torch",))
class MossTTSRealtimeInference:
"""Step-wise inference wrapper for MossTTSRealtime.
This class mirrors the non-streaming inference logic but exposes a
prefill/step/finish API for streaming usage.
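
    Example (a minimal sketch; `model` and `tokenizer` are assumed to be a loaded
    MossTTSRealtime model and its text tokenizer, and `prompt_ids` /
    `prefix_ids` / `remaining_text_ids` are placeholder inputs):

    ```python
    infer = MossTTSRealtimeInference(model, tokenizer)
    # Prefill with the conversation prompt ([T, channels + 1] token grid) plus
    # the first few text tokens; returns the first audio frame [B, channels].
    frame = infer.prefill(input_ids=[prompt_ids], text_prefix_ids=[prefix_ids])
    # Feed one text token per decoding step.
    for token_id in remaining_text_ids:
        frame = infer.step(token_id)
    # No more text: run with text padding until audio EOS (or max_length).
    frames = infer.finish()
    ```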
"""
def __init__(
self,
model,
tokenizer,
max_length: int = 1000,
channels: int = 16,
audio_channel_pad: int = 1024,
audio_bos_token: int = 1025,
audio_eos_token: int = 1026,
text_pad_id: int = 151655,
aud_pad_id: int = 151654,
):
self.model = model
self.tokenizer = tokenizer
self.max_length = max_length
self.channels = channels
self.audio_channel_pad = audio_channel_pad
self.audio_bos_token = audio_bos_token
self.audio_eos_token = audio_eos_token
self.text_pad_id = text_pad_id
self.aud_pad_id = aud_pad_id
self.past_key_values = None
self.attention_mask = None
self._generated_tokens: List[torch.Tensor] = []
self._is_stopping = None
self._last_audio_tokens = None
self._step_idx = 0
@property
def device(self):
return next(self.model.parameters()).device
@property
def is_finished(self) -> bool:
return self._is_stopping is not None and bool(self._is_stopping.all())
def reset_generation_state(self, keep_cache: bool = True):
if not keep_cache:
self.past_key_values = None
self.attention_mask = None
# Keep the mask when reusing cache so it stays aligned with past_key_values.
# This allows concatenation with the next turn prefill mask.
self._generated_tokens = []
self._is_stopping = None
self._last_audio_tokens = None
self._step_idx = 0
def _normalize_input_ids(self, input_ids):
if isinstance(input_ids, torch.Tensor):
input_ids = input_ids.detach().cpu().numpy()
if isinstance(input_ids, np.ndarray):
if input_ids.ndim == 2:
return [input_ids]
if input_ids.ndim == 3:
return [input_ids[i] for i in range(input_ids.shape[0])]
if isinstance(input_ids, (list, tuple)):
return [np.array(item) for item in input_ids]
raise ValueError("input_ids must be a list/array/tensor of shape [T, C] or [B, T, C].")
def _normalize_text_prefix(self, text_prefix_ids, batch_size: int) -> list[list[int]]:
if text_prefix_ids is None:
raise ValueError("text_prefix_ids must be provided for prefill.")
if isinstance(text_prefix_ids, torch.Tensor):
text_prefix_ids = text_prefix_ids.detach().cpu().tolist()
if isinstance(text_prefix_ids, np.ndarray):
text_prefix_ids = text_prefix_ids.tolist()
if isinstance(text_prefix_ids, list):
if len(text_prefix_ids) == 0:
return [[] for _ in range(batch_size)]
if isinstance(text_prefix_ids[0], (int, np.integer)):
return [list(text_prefix_ids)]
if len(text_prefix_ids) == 1 and batch_size > 1:
return [list(text_prefix_ids[0]) for _ in range(batch_size)]
if len(text_prefix_ids) != batch_size:
raise ValueError(
f"text_prefix_ids batch size mismatch: got {len(text_prefix_ids)}, expected {batch_size}."
)
return [list(item) for item in text_prefix_ids]
raise ValueError("text_prefix_ids must be list-like or tensor-like.")
@torch.inference_mode()
def prefill(
self,
input_ids,
text_prefix_ids,
max_prefill_len: Optional[int] = None,
past_key_values=None,
device: Optional[torch.device] = None,
temperature: float = 0.8,
top_p: float = 0.6,
top_k: int = 30,
do_sample: bool = True,
repetition_penalty: Optional[float] = 1.1,
repetition_window: Optional[int] = 50,
) -> torch.Tensor:
if device is None:
device = self.device
if past_key_values is not None:
self.past_key_values = past_key_values
input_ids_list = self._normalize_input_ids(input_ids)
batch_size = len(input_ids_list)
text_prefix_list = self._normalize_text_prefix(text_prefix_ids, batch_size)
concat_inputs_id_list = []
for i in range(batch_size):
prefix = text_prefix_list[i]
if max_prefill_len is not None:
prefix = prefix[:max_prefill_len]
if len(prefix) == 0:
raise ValueError("Prefill requires at least one text token.")
text_seg = np.full((len(prefix), self.channels + 1), self.audio_channel_pad, dtype=np.int64)
text_seg[:, 0] = np.array(prefix, dtype=np.int64)
text_seg[len(prefix) - 1, 1] = self.audio_bos_token
concat_inputs_id = np.concatenate([input_ids_list[i], text_seg], axis=0)
concat_inputs_id_list.append(concat_inputs_id)
attention_masks = [np.ones(ids.shape[0], dtype=np.bool_) for ids in concat_inputs_id_list]
max_len = max(ids.shape[0] for ids in concat_inputs_id_list)
padded_input_ids, padded_attns = [], []
pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.text_pad_id
for ids, attn in zip(concat_inputs_id_list, attention_masks):
pad_len = max_len - ids.shape[0]
input_pad = np.full((pad_len, self.channels + 1), self.audio_channel_pad, dtype=np.int64)
input_pad[:, 0] = pad_token_id
padded_input_ids.append(np.concatenate([input_pad, ids]))
attn_pad = np.zeros(pad_len, dtype=np.bool_)
padded_attns.append(np.concatenate([attn_pad, attn]))
current_input_ids = torch.from_numpy(np.stack(padded_input_ids)).to(device)
current_attention_mask = torch.from_numpy(np.stack(padded_attns)).to(device)
# For multi-turn continuation, concatenate the cached mask and the current prefill mask.
if self.attention_mask is not None and self.past_key_values is not None:
current_attention_mask = torch.cat([self.attention_mask, current_attention_mask], dim=-1)
outputs = self.model(
input_ids=current_input_ids,
attention_mask=current_attention_mask,
past_key_values=self.past_key_values,
use_cache=True,
return_dict=True,
)
self.past_key_values = outputs.past_key_values
self.attention_mask = current_attention_mask
backbone_hidden_states = outputs.last_hidden_state[:, -1:, :]
audio_tokens = self.generate_local_transformer(
hidden_states=backbone_hidden_states,
temperature=temperature,
top_p=top_p,
top_k=top_k,
do_sample=do_sample,
repetition_penalty=repetition_penalty,
repetition_window=repetition_window,
generated_tokens=None,
gen_step=0,
)
self._generated_tokens = [audio_tokens]
self._last_audio_tokens = audio_tokens
self._is_stopping = audio_tokens[:, 0] == self.audio_eos_token
self._step_idx = 1
return audio_tokens
@torch.inference_mode()
def step(
self,
text_token: Optional[Iterable[int] | torch.Tensor | int],
temperature: float = 0.8,
top_p: float = 0.6,
top_k: int = 30,
do_sample: bool = True,
repetition_penalty: Optional[float] = 1.1,
repetition_window: Optional[int] = 50,
) -> torch.Tensor:
if self._last_audio_tokens is None or self.attention_mask is None:
raise ValueError("You must call prefill() before step().")
if self.is_finished:
return self._last_audio_tokens
batch_size = self._last_audio_tokens.shape[0]
if text_token is None:
text_tokens = [self.text_pad_id] * batch_size
elif isinstance(text_token, torch.Tensor):
text_tokens = text_token.detach().cpu().tolist()
elif isinstance(text_token, (list, tuple, np.ndarray)):
text_tokens = list(text_token)
else:
text_tokens = [int(text_token)]
if len(text_tokens) != batch_size:
raise ValueError(f"text_token batch size mismatch: got {len(text_tokens)}, expected {batch_size}.")
device = self._last_audio_tokens.device
text_t = torch.tensor(text_tokens, device=device, dtype=torch.long)
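        # One step consumes a [B, 1, 1 + channels] frame: the incoming text token
        # in column 0, followed by the audio tokens sampled at the previous step.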
step_ids = torch.cat([text_t[:, None, None], self._last_audio_tokens.unsqueeze(1)], dim=2)
self.attention_mask = torch.cat([self.attention_mask, (~self._is_stopping).unsqueeze(-1)], dim=-1)
outputs = self.model(
input_ids=step_ids,
attention_mask=self.attention_mask,
past_key_values=self.past_key_values,
use_cache=True,
return_dict=True,
)
self.past_key_values = outputs.past_key_values
backbone_hidden_states = outputs.last_hidden_state[:, -1:, :]
history = torch.stack(self._generated_tokens, dim=1) if self._generated_tokens else None
audio_tokens = self.generate_local_transformer(
hidden_states=backbone_hidden_states,
temperature=temperature,
top_p=top_p,
top_k=top_k,
do_sample=do_sample,
repetition_penalty=repetition_penalty,
repetition_window=repetition_window,
generated_tokens=history,
gen_step=self._step_idx,
)
self._generated_tokens.append(audio_tokens)
self._last_audio_tokens = audio_tokens
self._is_stopping |= audio_tokens[:, 0] == self.audio_eos_token
self._step_idx += 1
return audio_tokens
@torch.inference_mode()
def finish(
self,
max_steps: Optional[int] = None,
temperature: float = 0.8,
top_p: float = 0.6,
top_k: int = 30,
do_sample: bool = True,
repetition_penalty: Optional[float] = 1.1,
repetition_window: Optional[int] = 50,
) -> list[torch.Tensor]:
outputs = []
steps_left = max_steps if max_steps is not None else self.max_length
while steps_left > 0 and not self.is_finished:
outputs.append(
self.step(
text_token=None,
temperature=temperature,
top_p=top_p,
top_k=top_k,
do_sample=do_sample,
repetition_penalty=repetition_penalty,
repetition_window=repetition_window,
)
)
steps_left -= 1
return outputs
@torch.compile(fullgraph=True)
def generate_local_transformer(
self,
hidden_states: torch.Tensor,
temperature: float,
top_p: float,
top_k: int,
do_sample: bool,
repetition_penalty: Optional[float],
repetition_window: Optional[int],
generated_tokens: Optional[torch.Tensor],
gen_step: int,
) -> torch.Tensor:
batch_size = hidden_states.shape[0]
device = hidden_states.device
local_inputs = hidden_states.reshape(-1, 1, self.model.config.local_config.hidden_size)
output_token = torch.empty(batch_size, self.channels, dtype=torch.long, device=device)
past_key_values = StaticCache(config=self.model.local_transformer.config, max_cache_len=self.channels)
local_token = None
cache_pos_t = torch.zeros(1, dtype=torch.long, device=device)
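        # Decode the codebooks autoregressively: channel 0 is conditioned on the
        # backbone hidden state via inputs_embeds; each later channel consumes
        # the token sampled for the previous codebook as input_ids.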
for i in range(self.channels):
cache_pos_t.fill_(i)
local_outputs = self.model.local_transformer(
input_ids=local_token,
inputs_embeds=local_inputs,
past_key_values=past_key_values,
cache_position=cache_pos_t,
codebook_idx=i,
use_cache=True,
logits_to_keep=1,
)
logits = local_outputs.logits
if repetition_penalty and repetition_penalty != 1.0 and generated_tokens is not None:
logits = self.apply_repetition_penalty(
scores=logits,
history_tokens=generated_tokens[:, :gen_step, i],
penalty=float(repetition_penalty),
repetition_window=repetition_window,
)
local_token = self.sample_token(
logits=logits,
temperature=temperature,
top_p=top_p,
top_k=top_k,
do_sample=do_sample,
)
output_token[:, i] = local_token.squeeze(-1)
if i == 0:
local_inputs = None
return output_token
    def apply_repetition_penalty(
        self,
        scores: torch.Tensor,
        history_tokens: torch.Tensor,
        penalty: float = 1.1,
        repetition_window: Optional[int] = None,
    ) -> torch.Tensor:
        # scores: [B, 1, vocab]; history_tokens: [B, T] ids generated so far for
        # the current codebook, optionally restricted to a trailing window.
        if repetition_window is not None and repetition_window > 0:
            history_tokens = history_tokens[:, -repetition_window:]
        scores_ = scores[:, 0, :]
        cur = scores_.gather(1, history_tokens)
        new = torch.where(cur < 0, cur * penalty, cur / penalty)
        scores_.scatter_(1, history_tokens, new)
        # `scores_` is a view, so the in-place scatter above already updated
        # `scores`; return the full [B, 1, vocab] tensor so the penalized and
        # unpenalized paths feed `sample_token` the same shape.
        return scores
def sample_token(self, logits, temperature, top_p=0.6, top_k=30, do_sample=True):
if not do_sample or temperature == 0:
return torch.argmax(logits, dim=-1)
logits = logits / temperature
original_shape = logits.shape
vocab_size = original_shape[-1]
reshaped_logits = logits.reshape(-1, vocab_size)
if top_k is not None:
reshaped_logits = self.apply_top_k(reshaped_logits, top_k)
if top_p is not None:
reshaped_logits = self.apply_top_p(reshaped_logits, top_p)
probs = F.softmax(reshaped_logits, dim=-1)
next_tokens_flat = torch.multinomial(probs, num_samples=1)
output_shape = original_shape[:-1]
return next_tokens_flat.view(output_shape)
def apply_top_k(self, logits, top_k, filter_value=float("-inf"), min_tokens_to_keep: int = 1):
if not isinstance(top_k, int) or top_k <= 0:
raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")
batch_size, vocab_size = logits.shape
top_k = max(top_k, min_tokens_to_keep)
top_k = min(top_k, vocab_size)
indices_to_remove = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
return logits.masked_fill(logits < indices_to_remove, filter_value)
def apply_top_p(self, logits, top_p, filter_value=float("-inf"), min_tokens_to_keep: int = 1):
top_p = float(top_p)
        if top_p < 0 or top_p > 1.0:
            raise ValueError(f"`top_p` has to be a float >= 0 and <= 1, but is {top_p}")
sorted_logits, sorted_indices = torch.sort(logits, descending=False)
cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool).scatter(1, sorted_indices, sorted_indices_to_remove)
logits_processed = logits.masked_fill(indices_to_remove, filter_value)
return logits_processed
@requires(backends=("torch",))
class MossTTSRealtimeStreamingSession:
"""Manage text-to-audio streaming for a single conversation."""
_split_pattern = re.compile(
r"[。!?!?\.\u2026]\s*" # sentence boundaries: 。!? ! ? . …
r"|[,,;;::\u2014\u2013\-]\s*" # short pauses: , , ; ; : : — – -
r"|\)\s*|\]\s*" # closing brackets: ) ]
r"|\n"
)
def __init__(
self,
inferencer: MossTTSRealtimeInference,
processor,
codec=None,
codec_sample_rate: int = 24000,
codec_encode_kwargs: Optional[dict] = None,
prefill_text_len: int = 12,
text_buffer_size: int = 32,
min_text_chunk_chars: int = 8,
temperature: float = 0.8,
top_p: float = 0.6,
top_k: int = 30,
do_sample: bool = True,
repetition_penalty: Optional[float] = 1.1,
repetition_window: Optional[int] = 50,
):
self.inferencer = inferencer
self.processor = processor
self.tokenizer = processor.tokenizer
self.codec = codec
self.codec_sample_rate = codec_sample_rate
self.codec_encode_kwargs = codec_encode_kwargs or {}
self.prefill_text_len = prefill_text_len
self.text_buffer_size = text_buffer_size
self.min_text_chunk_chars = min_text_chunk_chars
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
self.do_sample = do_sample
self.repetition_penalty = repetition_penalty
self.repetition_window = repetition_window
self._voice_prompt_tokens = None
self._turn_input_ids = None
self._turn_idx = 0
self._text_cache = ""
self._pending_tokens: list[int] = []
self._prefilled = False
self._text_ended = False
def set_voice_prompt_tokens(self, audio_tokens: np.ndarray):
self._voice_prompt_tokens = audio_tokens
def set_voice_prompt(self, audio, sample_rate: Optional[int] = None):
"""Set voice prompt from either audio tokens or waveform.
If `audio` is a 2D array whose shape matches the codebook channels, it is
treated as audio tokens. Otherwise a codec is required to encode waveform
prompts into tokens.
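
        Example (a sketch; `wav` is an assumed mono waveform tensor and
        `"speaker.wav"` a placeholder path):

        ```python
        session.set_voice_prompt("speaker.wav")           # loaded via torchaudio
        session.set_voice_prompt(wav, sample_rate=16000)  # resampled and encoded
        ```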
"""
if isinstance(audio, np.ndarray) and audio.ndim == 2:
if self.processor.channels in audio.shape:
self._voice_prompt_tokens = audio
return
if isinstance(audio, torch.Tensor) and audio.dim() == 2:
if self.processor.channels in audio.shape:
self._voice_prompt_tokens = audio.detach().cpu().numpy()
return
if self.codec is None:
raise ValueError("codec is required to encode waveform prompts.")
waveform = audio
if isinstance(audio, (str, bytes)):
requires_backends(self, ["torchaudio"])
wav, sr = torchaudio.load(audio)
if wav.shape[0] > 1:
wav = wav.mean(dim=0, keepdim=True)
waveform = wav.squeeze(0)
sample_rate = sr
if isinstance(waveform, np.ndarray):
waveform = torch.from_numpy(waveform)
if not isinstance(waveform, torch.Tensor):
raise ValueError("Unsupported audio type for voice prompt.")
if sample_rate is not None and sample_rate != self.codec_sample_rate:
requires_backends(self, ["torchaudio"])
waveform = torchaudio.functional.resample(waveform, sample_rate, self.codec_sample_rate)
waveform = waveform.to(self.inferencer.device)
encode_out = self.codec.encode([waveform], **self.codec_encode_kwargs)
if isinstance(encode_out, dict):
if "codes_list" in encode_out:
tokens = encode_out["codes_list"][0]
elif "audio_codes" in encode_out:
tokens = encode_out["audio_codes"][0]
else:
raise ValueError("codec.encode output missing audio codes.")
else:
tokens = encode_out
if isinstance(tokens, torch.Tensor):
tokens = tokens.detach().cpu().numpy()
self._voice_prompt_tokens = tokens
def clear_voice_prompt(self):
self._voice_prompt_tokens = None
def reset_turn(
self,
user_text: Optional[str] = None,
user_audio_tokens: Optional[np.ndarray] = None,
input_ids: Optional[np.ndarray] = None,
include_system_prompt: Optional[bool] = None,
reset_cache: bool = False,
):
if include_system_prompt is None:
include_system_prompt = self._turn_idx == 0
if input_ids is None:
if user_text is None or user_audio_tokens is None:
raise ValueError("user_text and user_audio_tokens are required when input_ids is not provided.")
user_prompt = self.processor.make_user_prompt(user_text, user_audio_tokens)
if include_system_prompt:
system_prompt = self.processor.make_ensemble(self._voice_prompt_tokens)
input_ids = np.concatenate([system_prompt, user_prompt], axis=0)
else:
input_ids = user_prompt
self._turn_input_ids = input_ids
self._turn_idx += 1
self._text_cache = ""
self._pending_tokens = []
self._prefilled = False
self._text_ended = False
self.inferencer.reset_generation_state(keep_cache=not reset_cache)
def push_text_tokens(self, tokens: Iterable[int]) -> list[torch.Tensor]:
self._pending_tokens.extend([int(t) for t in tokens])
return self._drain_pending_tokens()
def push_text(self, text_fragment: str) -> list[torch.Tensor]:
self._text_cache += text_fragment
segments = self._extract_text_segments(force=False)
for segment in segments:
self._pending_tokens.extend(self._tokenize(segment))
return self._drain_pending_tokens()
def end_text(self) -> list[torch.Tensor]:
self._text_ended = True
if self._text_cache:
self._pending_tokens.extend(self._tokenize(self._text_cache))
self._text_cache = ""
return self._drain_pending_tokens()
def drain(self, max_steps: Optional[int] = None) -> list[torch.Tensor]:
if not self._prefilled:
return []
return self.inferencer.finish(
max_steps=max_steps,
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
do_sample=self.do_sample,
repetition_penalty=self.repetition_penalty,
repetition_window=self.repetition_window,
)
def _tokenize(self, text: str) -> list[int]:
return self.tokenizer.encode(text, add_special_tokens=False)
def _extract_text_segments(self, force: bool) -> list[str]:
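        # Prefer cutting at a punctuation boundary once `min_text_chunk_chars`
        # characters have accumulated; once the cache exceeds `text_buffer_size`,
        # fall back to the last whitespace so unpunctuated text still flushes.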
segments = []
if force:
if self._text_cache:
segments.append(self._text_cache)
self._text_cache = ""
return segments
while self._text_cache:
cut_idx = None
if len(self._text_cache) >= self.min_text_chunk_chars:
matches = list(self._split_pattern.finditer(self._text_cache))
for match in matches:
if match.end() >= self.min_text_chunk_chars:
cut_idx = match.end()
break
if cut_idx is None and len(self._text_cache) >= self.text_buffer_size:
whitespace_idx = self._text_cache.rfind(" ")
if whitespace_idx != -1:
cut_idx = whitespace_idx + 1
if cut_idx is None:
break
segments.append(self._text_cache[:cut_idx])
self._text_cache = self._text_cache[cut_idx:]
return segments
def _prefill_if_needed(self) -> list[torch.Tensor]:
if self._prefilled:
return []
if not self._pending_tokens and not self._text_ended:
return []
if len(self._pending_tokens) < self.prefill_text_len and not self._text_ended:
return []
if self._turn_input_ids is None:
raise ValueError("reset_turn must be called before streaming text.")
if self._text_ended:
prefill_len = len(self._pending_tokens)
else:
prefill_len = min(len(self._pending_tokens), self.prefill_text_len)
if prefill_len == 0:
return []
prefix_tokens = [self._pending_tokens.pop(0) for _ in range(prefill_len)]
audio_tokens = self.inferencer.prefill(
input_ids=[self._turn_input_ids],
text_prefix_ids=[prefix_tokens],
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
do_sample=self.do_sample,
repetition_penalty=None,
repetition_window=self.repetition_window,
)
self._prefilled = True
return [audio_tokens]
def _drain_pending_tokens(self) -> list[torch.Tensor]:
outputs: list[torch.Tensor] = []
outputs.extend(self._prefill_if_needed())
if not self._prefilled:
return outputs
while self._pending_tokens and not self.inferencer.is_finished:
token = self._pending_tokens.pop(0)
outputs.append(
self.inferencer.step(
token,
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
do_sample=self.do_sample,
repetition_penalty=self.repetition_penalty,
repetition_window=self.repetition_window,
)
)
return outputs
@requires(backends=("torch",))
class AudioStreamDecoder:
"""Decode audio tokens into waveform chunks with optional crossfade."""
def __init__(
self,
codec,
chunk_frames: int = 40,
overlap_frames: int = 4,
decode_kwargs: Optional[dict] = None,
device: Optional[torch.device] = None,
):
self.codec = codec
self.chunk_frames = chunk_frames
self.overlap_frames = overlap_frames
self.decode_kwargs = decode_kwargs or {}
self.device = device
self._buffer: list[torch.Tensor] = []
self._buffer_len = 0
self._prev_tail: Optional[torch.Tensor] = None
def push_tokens(self, audio_tokens: np.ndarray | torch.Tensor):
if isinstance(audio_tokens, np.ndarray):
audio_tokens = torch.from_numpy(audio_tokens)
if audio_tokens.dim() != 2:
raise ValueError(f"Expected [T, C] audio tokens, got {tuple(audio_tokens.shape)}")
self._buffer.append(audio_tokens)
self._buffer_len += audio_tokens.shape[0]
def audio_chunks(self) -> Iterable[torch.Tensor]:
while self._buffer_len >= self.chunk_frames:
chunk_tokens = self._consume_frames(self.chunk_frames)
wav = self._decode(chunk_tokens, chunk_duration=0.32)
yield self._apply_crossfade(wav)
def flush(self) -> Optional[torch.Tensor]:
if self._buffer_len == 0:
return None
chunk_tokens = self._consume_frames(self._buffer_len)
wav = self._decode(chunk_tokens)
return self._apply_crossfade(wav, final_chunk=True)
def _consume_frames(self, num_frames: int) -> torch.Tensor:
frames = []
remaining = num_frames
while remaining > 0 and self._buffer:
head = self._buffer[0]
if head.shape[0] <= remaining:
frames.append(head)
remaining -= head.shape[0]
self._buffer.pop(0)
else:
frames.append(head[:remaining])
self._buffer[0] = head[remaining:]
remaining = 0
self._buffer_len -= num_frames - remaining
return torch.cat(frames, dim=0)
def _decode(self, tokens: torch.Tensor, chunk_duration: float = 0.32) -> torch.Tensor:
device = self.device
if device is None:
if hasattr(self.codec, "device"):
device = self.codec.device
else:
try:
device = next(self.codec.parameters()).device
except Exception:
device = None
if device is not None:
tokens = tokens.to(device)
tokens_t = tokens.permute(1, 0)
# allow callers to override decode settings (e.g. chunk_duration=-1 to disable internal streaming)
decode_kwargs = dict(self.decode_kwargs) if self.decode_kwargs else {}
if "chunk_duration" in decode_kwargs:
override = decode_kwargs.pop("chunk_duration")
if override is None:
chunk_duration_arg = None
else:
try:
override_f = float(override)
except Exception:
override_f = None
chunk_duration_arg = None if override_f is None or override_f <= 0 else override_f
else:
chunk_duration_arg = chunk_duration
decoded = self.codec.decode(tokens_t, chunk_duration=chunk_duration_arg, **decode_kwargs)
if isinstance(decoded, dict):
wav = decoded["audio"][0]
else:
wav = decoded
if isinstance(wav, np.ndarray):
wav = torch.from_numpy(wav)
if wav.dim() > 1:
wav = wav.squeeze(0)
return wav
def _apply_crossfade(self, wav: torch.Tensor, final_chunk: bool = False) -> torch.Tensor:
if self.overlap_frames <= 0:
return wav
        if self._prev_tail is None:
            overlap = self._overlap_samples(wav)
            # Guard overlap == 0: `wav[-0:]` would alias the entire chunk.
            self._prev_tail = wav[-overlap:].clone() if (not final_chunk and overlap > 0) else None
            return wav
overlap = self._overlap_samples(wav)
if overlap == 0:
return wav
prev_tail = self._prev_tail
if prev_tail.numel() < overlap:
overlap = prev_tail.numel()
if overlap == 0:
return wav
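        # Linear crossfade: the saved tail fades out while the new chunk's head
        # fades in, and the blended overlap is spliced between the two chunks.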
fade_out = torch.linspace(1.0, 0.0, overlap, device=wav.device)
fade_in = 1.0 - fade_out
cross = prev_tail[-overlap:] * fade_out + wav[:overlap] * fade_in
merged = torch.cat([prev_tail[:-overlap], cross, wav[overlap:]], dim=-1)
self._prev_tail = None if final_chunk else wav[-overlap:].clone()
return merged
def _overlap_samples(self, wav: torch.Tensor) -> int:
if self.chunk_frames <= 0:
return 0
return int(wav.numel() * (self.overlap_frames / self.chunk_frames))
class TextDeltaTokenizer:
"""
Convert LLM streaming text (delta) into “incremental token IDs”.
Notes:
    - Inputs are deltas progressively appended to one growing string
      (matching the common delta semantics of vLLM-style streaming output).
- Each time, re-encode the *full text* with the tokenizer, then take only
the newly added token IDs.
- This guarantees that tokenization is consistent with the final complete
text, avoiding boundary mismatches caused by tokenizing partial segments.
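
    Example (a sketch; `tokenizer` is the TTS text tokenizer and `session` a
    `MossTTSRealtimeStreamingSession`):

    ```python
    delta_tok = TextDeltaTokenizer(tokenizer, hold_back=3)
    for delta in ("Hel", "lo, wor", "ld!"):
        stable_ids = delta_tok.push_delta(delta)  # may be empty early on
        session.push_text_tokens(stable_ids)
    session.push_text_tokens(delta_tok.flush())   # emit the held-back tail
    ```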
"""
def __init__(self, tokenizer, *, hold_back: int = 3):
self.tokenizer = tokenizer
self.hold_back = max(0, int(hold_back))
self._text = ""
self._all_ids: list[int] = []
self._emitted_count: int = 0
@property
def text(self) -> str:
return self._text
@property
def token_ids(self) -> list[int]:
return list(self._all_ids)
def push_delta(self, delta: str) -> list[int]:
"""Append a text delta and return newly stable token ids (may be empty)."""
if not delta:
return []
self._text += str(delta)
self._all_ids = self.tokenizer.encode(self._text, add_special_tokens=False)
# Keep the tail un-emitted because the latest tokens can still change.
stable_count = max(self._emitted_count, len(self._all_ids) - self.hold_back)
new_ids = self._all_ids[self._emitted_count : stable_count]
self._emitted_count = stable_count
return new_ids
def flush(self) -> list[int]:
"""Emit all remaining token ids at end of stream."""
self._all_ids = self.tokenizer.encode(self._text, add_special_tokens=False)
remaining = self._all_ids[self._emitted_count :]
self._emitted_count = len(self._all_ids)
return remaining
def _sanitize_audio_tokens(
tokens: torch.Tensor,
*,
codebook_size: int,
audio_eos_token: int,
) -> tuple[torch.Tensor, bool]:
"""Trim rows after EOS/invalid tokens and return whether decoding should stop."""
if tokens.dim() == 1:
tokens = tokens.unsqueeze(0)
if tokens.numel() == 0:
return tokens, False
eos_rows = (tokens[:, 0] == audio_eos_token).nonzero(as_tuple=False)
invalid_rows = ((tokens < 0) | (tokens >= codebook_size)).any(dim=1)
stop_idx = None
if eos_rows.numel() > 0:
stop_idx = int(eos_rows[0].item())
if invalid_rows.any():
invalid_idx = int(invalid_rows.nonzero(as_tuple=False)[0].item())
stop_idx = invalid_idx if stop_idx is None else min(stop_idx, invalid_idx)
if stop_idx is not None:
return tokens[:stop_idx], True
return tokens, False
def _maybe_codec_streaming(codec, *, batch_size: int):
if codec is None or not hasattr(codec, "streaming"):
return contextlib.nullcontext()
return codec.streaming(batch_size=batch_size)
@requires(backends=("torch",))
class MossTTSRealtimeTextStreamBridge:
"""
Bridge: external LLM streaming text (delta) -> TTS streaming audio chunks.
Usage overview:
- First configure `MossTTSRealtimeStreamingSession` (especially `prefill_text_len=12`).
- Provide an `AudioStreamDecoder`, then continuously feed the LLM delta text via
`push_text_delta()`.
- Once the accumulated token count reaches `prefill_text_len`, the session will
start generating audio tokens; the bridge will immediately decode them into WAV
chunks and yield them.
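
    Example (a sketch; `session`, `decoder`, and the `llm_deltas` iterator are
    assumed, and `play` is a placeholder audio sink):

    ```python
    bridge = MossTTSRealtimeTextStreamBridge(session, decoder)
    for wav in bridge.stream_from_text_deltas(llm_deltas):
        play(wav)
    ```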
"""
def __init__(
self,
session: MossTTSRealtimeStreamingSession,
decoder: AudioStreamDecoder,
*,
codebook_size: Optional[int] = None,
audio_eos_token: Optional[int] = None,
batch_size: int = 1,
):
self.session = session
self.decoder = decoder
self.batch_size = int(batch_size)
if codebook_size is None:
codebook_size = int(getattr(getattr(session, "codec", None), "codebook_size", 1024))
if audio_eos_token is None:
audio_eos_token = int(getattr(session.inferencer, "audio_eos_token", 1026))
self.codebook_size = int(codebook_size)
self.audio_eos_token = int(audio_eos_token)
def push_text_delta(self, delta: str) -> Iterator[torch.Tensor]:
"""
Push a chunk of incremental text output from the LLM and return newly generated WAV chunks.
Internally, this directly calls `session.push_text()`, which segments the text
based on punctuation/length and then tokenizes the *entire segment* at once,
avoiding the prefix instability issues of incremental BPE tokenization.
"""
audio_frames = self.session.push_text(delta)
yield from self._decode_audio_frames(audio_frames)
def push_text_tokens(self, token_ids: Sequence[int]) -> Iterator[torch.Tensor]:
"""Push token ids directly (for sources that stream token ids)."""
if not token_ids:
return
audio_frames = self.session.push_text_tokens(token_ids)
yield from self._decode_audio_frames(audio_frames)
def finish(self, *, drain_step: int = 1) -> Iterator[torch.Tensor]:
"""Mark text stream end and emit all remaining audio chunks (including flush)."""
audio_frames = self.session.end_text()
yield from self._decode_audio_frames(audio_frames)
while True:
more_frames = self.session.drain(max_steps=drain_step)
if not more_frames:
break
yield from self._decode_audio_frames(more_frames)
if self.session.inferencer.is_finished:
break
final = self.decoder.flush()
if final is not None and final.numel() > 0:
yield final.detach().cpu()
def stream_from_text_deltas(self, deltas: Iterable[str], *, drain_step: int = 1) -> Iterator[torch.Tensor]:
"""Consume a full delta iterator and continuously yield waveform chunks."""
with _maybe_codec_streaming(getattr(self.session, "codec", None), batch_size=self.batch_size):
for delta in deltas:
yield from self.push_text_delta(delta)
yield from self.finish(drain_step=drain_step)
def _decode_audio_frames(self, audio_frames: list[torch.Tensor]) -> Iterator[torch.Tensor]:
for frame in audio_frames:
tokens = frame
if tokens.dim() == 3:
tokens = tokens[0]
if tokens.dim() != 2:
raise ValueError(f"Expected [B, C] or [1, C] audio tokens, got {tuple(tokens.shape)}")
if tokens.shape[0] != 1:
raise ValueError(
f"This bridge currently supports batch_size=1 for decoding, got batch={tokens.shape[0]}."
)
tokens, stop = _sanitize_audio_tokens(
tokens,
codebook_size=self.codebook_size,
audio_eos_token=self.audio_eos_token,
)
if tokens.numel() == 0:
if stop:
break
continue
self.decoder.push_tokens(tokens.detach())
for wav in self.decoder.audio_chunks():
if wav.numel() == 0:
continue
yield wav.detach().cpu()
if stop:
break
__all__ = [
"AudioStreamDecoder",
"MossTTSRealtimeInference",
"MossTTSRealtimeStreamingSession",
"MossTTSRealtimeTextStreamBridge",
"TextDeltaTokenizer",
]