# Source provenance (HF Space page header, preserved as comments so the file parses):
# Zhyw's picture
# Update app.py
# 7ab64cf verified
import spaces
import argparse
import base64
import functools
import json
import sys
import threading
import time
from collections import OrderedDict
from dataclasses import dataclass
from pathlib import Path
from typing import Iterator, Sequence
import gradio as gr
import numpy as np
import os
os.environ["TORCHDYNAMO_DISABLE"] = "1"
import torch
import torchaudio
import torch._dynamo
from transformers import AutoModel, AutoTokenizer
from mossttsrealtime import MossTTSRealtime, MossTTSRealtimeProcessor
from mossttsrealtime.streaming_mossttsrealtime import (
AudioStreamDecoder,
MossTTSRealtimeInference,
MossTTSRealtimeStreamingSession,
)
# Cap the number of compiled-graph cache entries so dynamo recompiles stay bounded.
torch._dynamo.config.cache_size_limit = 64
APP_DIR = Path(__file__).resolve().parent
AUDIO_DIR = APP_DIR / "asset"  # bundled demo audio assets
LOG_DIR = APP_DIR / "logs"  # JSONL RTF logs are appended here
SAMPLE_RATE = 24000  # codec PCM sample rate in Hz
CODEC_MODEL_PATH = "OpenMOSS-Team/MOSS-Audio-Tokenizer"
MODEL_PATH = "OpenMOSS-Team/MOSS-TTS-Realtime"
TOKENIZER_PATH = "OpenMOSS-Team/MOSS-TTS-Realtime"
PROMPT_WAV = "asset/prompt_audio.mp3"  # default voice-prompt audio
USER_WAV = "asset/user1.wav"  # default user-turn audio
WARMUP_POLL_INTERVAL_SECONDS = 0.5
DEFAULT_REPETITION_WINDOW = 50
# One step past the repetition window so warmup also exercises the post-window path.
WARMUP_STEP_TOKENS = DEFAULT_REPETITION_WINDOW + 1
WARMUP_USER_TEXT = "Hello!"
WARMUP_BASE_ASSISTANT_TEXT = (
    "This startup warmup request primes the streaming text to speech path "
    "so the first real user request avoids the cold compile stall."
)
def _apply_seed(seed: int | None) -> None:
if seed is None:
return
# ZeroGPU: avoid touching torch.cuda outside the managed GPU call.
torch.manual_seed(seed)
def _load_audio(path: Path, target_sample_rate: int = SAMPLE_RATE) -> torch.Tensor:
    """Load an audio file, resample to the target rate, and downmix to mono."""
    waveform, source_rate = torchaudio.load(path)
    if source_rate != target_sample_rate:
        waveform = torchaudio.functional.resample(waveform, source_rate, target_sample_rate)
    # Average channels into a single mono channel when the file is multi-channel.
    return waveform if waveform.shape[0] <= 1 else waveform.mean(dim=0, keepdim=True)
def _load_codec(device: torch.device, codec_model_path: str):
    """Load the audio codec model in eval mode and place it on *device*."""
    return AutoModel.from_pretrained(codec_model_path, trust_remote_code=True).eval().to(device)
def _extract_codes(encode_result):
if isinstance(encode_result, dict):
codes = encode_result["audio_codes"]
elif isinstance(encode_result, (list, tuple)) and encode_result:
codes = encode_result[0]
else:
codes = encode_result
if isinstance(codes, np.ndarray):
codes = torch.from_numpy(codes)
if isinstance(codes, torch.Tensor) and codes.dim() == 3:
if codes.shape[1] == 1:
codes = codes[:, 0, :]
elif codes.shape[0] == 1:
codes = codes[0]
else:
raise ValueError(f"Unsupported 3D audio code shape: {tuple(codes.shape)}")
return codes
@dataclass(frozen=True)
class BackendPaths:
    """Checkpoint identifiers and runtime placement for the backend models."""

    model_path: str  # TTS model repo id or local path
    tokenizer_path: str  # tokenizer repo id or local path
    codec_model_path: str  # audio codec repo id or local path
    device_str: str  # torch device string, e.g. "cuda" or "cpu"
    attn_impl: str  # attention implementation name; "none"/"" means default
@dataclass(frozen=True)
class GenerationConfig:
    """Sampling parameters for one generation request."""

    temperature: float
    top_p: float
    top_k: int
    repetition_penalty: float
    repetition_window: int  # how many recent tokens the penalty looks back over
    do_sample: bool
    max_length: int
    seed: int | None  # None means no fixed seed
@dataclass(frozen=True)
class StreamingConfig:
    """Chunking / pacing parameters for the streaming pipeline."""

    text_chunk_tokens: int  # text tokens pushed into the session per step
    input_delay: float  # artificial delay (s) between text chunks
    decode_chunk_frames: int  # codec frames decoded per audio chunk
    decode_overlap_frames: int  # overlap frames between decode chunks
    chunk_duration: float  # codec encode chunk duration (s)
    prebuffer_seconds: float  # audio held back before first emission
    buffer_threshold_seconds: float = 0.0  # backpressure threshold; 0 disables
@dataclass(frozen=True)
class StreamingRequest:
    """Everything needed to run one streaming TTS turn."""

    user_text: str
    assistant_text: str  # the text to synthesize
    prompt_audio: str | None  # explicit voice-prompt wav path, if any
    user_audio: str | None  # explicit user-turn wav path, if any
    use_default_prompt: bool  # fall back to PROMPT_WAV when prompt_audio is empty
    use_default_user: bool  # fall back to USER_WAV when user_audio is empty
    generation: GenerationConfig
    streaming: StreamingConfig
    backend: BackendPaths
@dataclass(frozen=True)
class StreamEvent:
    """One unit yielded to the UI: a status message plus optional audio chunk."""

    message: str
    audio: tuple[int, np.ndarray] | None = None  # (sample_rate, mono samples)
@dataclass(frozen=True)
class WarmupSnapshot:
    """Immutable view of warmup progress handed to the UI."""

    state: str  # "running" | "ready" | "failed"
    progress: float  # 0.0 .. 1.0
    message: str
    detail: str | None = None
    error: str | None = None

    @property
    def ready(self) -> bool:
        """True once warmup reached the ready state."""
        return self.state == "ready"

    @property
    def failed(self) -> bool:
        """True when warmup aborted with an error."""
        return self.state == "failed"
def _make_log_path(prefix: str) -> Path:
    """Ensure LOG_DIR exists and return a unique, timestamped .jsonl path."""
    LOG_DIR.mkdir(parents=True, exist_ok=True)
    timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime())
    # Sub-second nanoseconds keep paths unique within the same wall-clock second.
    nanos = time.time_ns() % 1_000_000_000
    return LOG_DIR / f"{prefix}_{timestamp}_{nanos:09d}.jsonl"
def _compute_rtf_metrics(sample_count: int, sample_rate: int, started_at: float) -> dict[str, float | None]:
elapsed_s = max(0.0, time.monotonic() - started_at)
audio_s = float(sample_count) / float(sample_rate) if sample_count > 0 and sample_rate > 0 else 0.0
rtf = (elapsed_s / audio_s) if audio_s > 0 else None
return {
"elapsed_s": elapsed_s,
"audio_s": audio_s,
"rtf": rtf,
}
class StreamRTFLogger:
    """Append-only JSONL logger for per-stream real-time-factor (RTF) metrics.

    One instance tracks a single streaming request: it accumulates chunk and
    sample counts and writes timestamped event records to ``path``.
    """

    def __init__(self, path: Path, started_at: float):
        self.path = path  # JSONL file records are appended to
        self.started_at = started_at  # monotonic timestamp the stream began
        self.chunk_count = 0  # non-empty chunks logged so far
        self.sample_rate = SAMPLE_RATE  # overwritten by the first logged chunk
        self.samples_emitted = 0  # cumulative emitted sample count

    @classmethod
    def create(cls, request: "StreamingRequest", started_at: float) -> "StreamRTFLogger":
        """Build a logger with a fresh timestamped file and record request start."""
        logger = cls(_make_log_path("rtf"), started_at)
        logger.log_request_started(request)
        print(f"[MossTTSRealtime][rtf-log] {logger.path}", flush=True)
        return logger

    def log_request_started(self, request: "StreamingRequest") -> None:
        """Record the request's generation/streaming settings as the first event."""
        self._append(
            {
                "event": "request_started",
                "user_text_chars": len(request.user_text),
                "assistant_text_chars": len(request.assistant_text),
                "text_chunk_tokens": request.streaming.text_chunk_tokens,
                "decode_chunk_frames": request.streaming.decode_chunk_frames,
                "decode_overlap_frames": request.streaming.decode_overlap_frames,
                "chunk_duration_s": request.streaming.chunk_duration,
                "prebuffer_seconds": request.streaming.prebuffer_seconds,
                "temperature": request.generation.temperature,
                "top_p": request.generation.top_p,
                "top_k": request.generation.top_k,
                "repetition_penalty": request.generation.repetition_penalty,
                "repetition_window": request.generation.repetition_window,
                "do_sample": request.generation.do_sample,
                "max_length": request.generation.max_length,
                "seed": request.generation.seed,
                "device": request.backend.device_str,
                "attn_implementation": request.backend.attn_impl,
            }
        )

    def log_chunk(
        self,
        *,
        event_message: str,
        sample_rate: int,
        chunk: np.ndarray,
        first_audio_time: float | None,
    ) -> None:
        """Record one emitted audio chunk; empty chunks are ignored."""
        chunk = np.asarray(chunk).reshape(-1)
        if chunk.size == 0:
            return
        self.chunk_count += 1
        self.sample_rate = int(sample_rate)
        self.samples_emitted += int(chunk.size)
        metrics = _compute_rtf_metrics(self.samples_emitted, self.sample_rate, self.started_at)
        record = {
            "event": "stream_chunk",
            "message": event_message,
            "chunk_idx": self.chunk_count,
            "chunk_audio_s": float(chunk.size) / float(self.sample_rate),
            "audio_s_emitted": metrics["audio_s"],
            "elapsed_s": metrics["elapsed_s"],
            "rtf": metrics["rtf"],
        }
        if first_audio_time is not None:
            record["time_to_first_audio_ms"] = max(0.0, (first_audio_time - self.started_at) * 1000.0)
        self._append(record)

    def log_completion(self, *, first_audio_time: float | None) -> None:
        """Record the final summary event for a stream that produced audio."""
        metrics = _compute_rtf_metrics(self.samples_emitted, self.sample_rate, self.started_at)
        record = {
            "event": "stream_complete",
            "chunk_count": self.chunk_count,
            "audio_s_total": metrics["audio_s"],
            "elapsed_s": metrics["elapsed_s"],
            "rtf": metrics["rtf"],
        }
        if first_audio_time is not None:
            record["time_to_first_audio_ms"] = max(0.0, (first_audio_time - self.started_at) * 1000.0)
        self._append(record)

    def log_no_audio(self) -> None:
        """Record a completion event for a stream that emitted no audio at all."""
        metrics = _compute_rtf_metrics(0, self.sample_rate, self.started_at)
        self._append(
            {
                "event": "stream_complete",
                "chunk_count": 0,
                "audio_s_total": 0.0,
                "elapsed_s": metrics["elapsed_s"],
                "rtf": None,
                "warning": "No audio chunks emitted.",
            }
        )

    def log_error(self, exc: Exception, *, first_audio_time: float | None) -> None:
        """Record a failure event with whatever progress had been made."""
        metrics = _compute_rtf_metrics(self.samples_emitted, self.sample_rate, self.started_at)
        record = {
            "event": "stream_error",
            "error_type": type(exc).__name__,
            "error": str(exc),
            "chunk_count": self.chunk_count,
            "audio_s_emitted": metrics["audio_s"],
            "elapsed_s": metrics["elapsed_s"],
            "rtf": metrics["rtf"],
        }
        if first_audio_time is not None:
            record["time_to_first_audio_ms"] = max(0.0, (first_audio_time - self.started_at) * 1000.0)
        self._append(record)

    def _append(self, payload: dict[str, object]) -> None:
        """Append one JSON record (with a wall-clock timestamp) to the log file."""
        record = {
            "ts": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
            **payload,
        }
        with self.path.open("a", encoding="utf-8") as handle:
            handle.write(json.dumps(record, ensure_ascii=False) + "\n")
class TokenChunkStream:
    """Iterate a token sequence in fixed-size chunks.

    A non-positive chunk size yields the whole sequence as a single chunk;
    an empty sequence yields nothing.
    """

    def __init__(
        self,
        tokens: Sequence[int],
        chunk_size: int,
    ):
        self._tokens = list(tokens)
        self._chunk_size = int(chunk_size)

    def __iter__(self) -> Iterator[list[int]]:
        total = len(self._tokens)
        if total == 0:
            return
        size = self._chunk_size if self._chunk_size > 0 else total
        yield from (self._tokens[start : start + size] for start in range(0, total, size))
class BufferedAudioTracker:
    """Track how much emitted audio is still ahead of real-time playback."""

    def __init__(self, sample_rate: int):
        self.sample_rate = sample_rate
        self.start_time: float | None = None  # set at the first non-empty chunk
        self.samples_emitted = 0

    def add_chunk(self, chunk: np.ndarray) -> None:
        """Account for an emitted chunk; the playback clock starts on the first one."""
        if chunk.size == 0:
            return
        if self.start_time is None:
            self.start_time = time.monotonic()
        self.samples_emitted += int(chunk.size)

    def buffered_seconds(self) -> float:
        """Seconds of emitted audio a real-time consumer would not yet have played."""
        if self.start_time is None:
            return 0.0
        playback_elapsed = time.monotonic() - self.start_time
        ahead = self.samples_emitted / self.sample_rate - playback_elapsed
        return ahead if ahead > 0.0 else 0.0
class AudioFrameDecoder:
    """Turn model audio-token frames into PCM chunks via an AudioStreamDecoder."""

    def __init__(
        self,
        decoder: AudioStreamDecoder,
        codebook_size: int,
        audio_eos_token: int,
    ):
        self.decoder = decoder
        self.codebook_size = codebook_size  # valid token ids are [0, codebook_size)
        self.audio_eos_token = audio_eos_token  # channel-0 id that ends the audio stream

    def decode_frames(self, audio_frames: list[torch.Tensor]) -> Iterator[np.ndarray]:
        """Yield flat mono PCM arrays for each decodable frame.

        Frames are sanitized first; once an EOS or out-of-range token is hit,
        remaining rows of that frame and all later frames are dropped.
        """
        for frame in audio_frames:
            tokens = frame
            # Accept [B, T, C] input by dropping the (assumed singleton) batch dim.
            if tokens.dim() == 3:
                tokens = tokens[0]
            if tokens.dim() != 2:
                raise ValueError(f"Expected [T, C] audio tokens, got {tuple(tokens.shape)}")
            tokens, stop = _sanitize_tokens(tokens, self.codebook_size, self.audio_eos_token)
            if tokens.numel() == 0:
                if stop:
                    break
                continue
            self.decoder.push_tokens(tokens.detach())
            for wav in self.decoder.audio_chunks():
                if wav.numel() == 0:
                    continue
                yield wav.detach().cpu().numpy().reshape(-1)
            if stop:
                break

    def flush(self) -> Iterator[np.ndarray]:
        """Yield any final buffered audio held by the underlying decoder."""
        final_chunk = self.decoder.flush()
        if final_chunk is not None and final_chunk.numel() > 0:
            yield final_chunk.detach().cpu().numpy().reshape(-1)
class StreamAudioEmitter:
    """Wrap decoded PCM chunks into StreamEvents, with optional prebuffering.

    When ``prebuffer_seconds`` > 0, chunks are held back until that much audio
    has accumulated, then released in one burst so playback starts with a
    safety margin against stutter.
    """

    def __init__(self, sample_rate: int, prebuffer_seconds: float):
        self.sample_rate = sample_rate
        self._buffer_tracker = BufferedAudioTracker(sample_rate)
        self._prebuffer_target = max(0.0, float(prebuffer_seconds))
        self._prebuffering = self._prebuffer_target > 0.0
        self._pending_chunks: list[np.ndarray] = []  # chunks held during prebuffer
        self._pending_samples = 0
        self.chunk_count = 0  # number of events emitted so far
        self.has_audio = False  # True once any event carried audio

    def wait_for_capacity(self, threshold_seconds: float) -> None:
        """Block while more than ``threshold_seconds`` of audio is already buffered."""
        _maybe_wait_for_buffer(self._buffer_tracker, threshold_seconds)

    def emit_many(self, chunks: Iterator[np.ndarray], message_prefix: str) -> Iterator[StreamEvent]:
        """Emit events for every chunk produced by ``chunks``."""
        for chunk in chunks:
            yield from self.emit(chunk, message_prefix)

    def emit(self, chunk: np.ndarray, message_prefix: str) -> Iterator[StreamEvent]:
        """Emit one chunk, holding it back while prebuffering is still active."""
        chunk = np.asarray(chunk).reshape(-1)
        if chunk.size == 0:
            return
        if self._prebuffering:
            self._pending_chunks.append(chunk)
            self._pending_samples += int(chunk.size)
            # Keep accumulating until the prebuffer target is reached...
            if (self._pending_samples / self.sample_rate) < self._prebuffer_target:
                return
            # ...then release everything collected so far, in order.
            self._prebuffering = False
            pending_chunks = self._pending_chunks
            self._pending_chunks = []
            self._pending_samples = 0
            for pending in pending_chunks:
                yield self._make_event(pending, message_prefix)
            return
        yield self._make_event(chunk, message_prefix)

    def flush(self, message_prefix: str) -> Iterator[StreamEvent]:
        """Release any chunks still held by an unfinished prebuffer."""
        if not self._prebuffering or not self._pending_chunks:
            self._prebuffering = False
            return
        self._prebuffering = False
        pending_chunks = self._pending_chunks
        self._pending_chunks = []
        self._pending_samples = 0
        for chunk in pending_chunks:
            yield self._make_event(chunk, message_prefix)

    def _make_event(self, chunk: np.ndarray, message_prefix: str) -> StreamEvent:
        """Number the chunk, update buffer accounting, and build the event."""
        self.chunk_count += 1
        self.has_audio = True
        self._buffer_tracker.add_chunk(chunk)
        return StreamEvent(
            message=f"{message_prefix} chunk {self.chunk_count}",
            audio=(self.sample_rate, chunk),
        )
def _maybe_wait_for_buffer(buffer_tracker: BufferedAudioTracker, threshold_seconds: float) -> None:
if threshold_seconds <= 0:
return
while buffer_tracker.buffered_seconds() > threshold_seconds:
time.sleep(0.01)
def _sanitize_tokens(
tokens: torch.Tensor,
codebook_size: int,
audio_eos_token: int,
) -> tuple[torch.Tensor, bool]:
if tokens.dim() == 1:
tokens = tokens.unsqueeze(0)
if tokens.numel() == 0:
return tokens, False
eos_rows = (tokens[:, 0] == audio_eos_token).nonzero(as_tuple=False)
invalid_rows = ((tokens < 0) | (tokens >= codebook_size)).any(dim=1)
stop_idx = None
if eos_rows.numel() > 0:
stop_idx = int(eos_rows[0].item())
if invalid_rows.any():
invalid_idx = int(invalid_rows.nonzero(as_tuple=False)[0].item())
stop_idx = invalid_idx if stop_idx is None else min(stop_idx, invalid_idx)
if stop_idx is not None:
tokens = tokens[:stop_idx]
return tokens, True
return tokens, False
def _build_streaming_session(
    model: MossTTSRealtime,
    tokenizer,
    processor: MossTTSRealtimeProcessor,
    codec,
    *,
    max_length: int,
    chunk_duration: float,
    temperature: float,
    top_p: float,
    top_k: int,
    do_sample: bool,
    repetition_penalty: float,
    repetition_window: int,
) -> tuple[MossTTSRealtimeStreamingSession, MossTTSRealtimeInference]:
    """Create a fresh streaming session plus its underlying inference engine.

    The inference state is reset (cache dropped) so each session starts clean.
    Returns (session, inferencer).
    """
    inference = MossTTSRealtimeInference(model, tokenizer, max_length=max_length)
    # Start from a clean generation state; do not reuse any cached KV entries.
    inference.reset_generation_state(keep_cache=False)
    sampling_kwargs = dict(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        do_sample=do_sample,
        repetition_penalty=repetition_penalty,
        repetition_window=repetition_window,
    )
    session = MossTTSRealtimeStreamingSession(
        inference,
        processor,
        codec=codec,
        codec_sample_rate=SAMPLE_RATE,
        codec_encode_kwargs={"chunk_duration": chunk_duration},
        prefill_text_len=processor.delay_tokens_len,
        **sampling_kwargs,
    )
    return session, inference
def _build_frame_decoder(
    codec,
    inferencer: MossTTSRealtimeInference,
    device: torch.device,
    *,
    chunk_frames: int,
    overlap_frames: int,
) -> AudioFrameDecoder:
    """Construct an AudioFrameDecoder over a streaming codec decoder.

    Falls back to 1024 / 1026 when the codec or inferencer do not expose
    ``codebook_size`` / ``audio_eos_token``.
    """
    stream_decoder = AudioStreamDecoder(
        codec,
        chunk_frames=chunk_frames,
        overlap_frames=overlap_frames,
        decode_kwargs={"chunk_duration": -1},
        device=device,
    )
    codebook_size = int(getattr(codec, "codebook_size", 1024))
    eos_token = int(getattr(inferencer, "audio_eos_token", 1026))
    return AudioFrameDecoder(stream_decoder, codebook_size, eos_token)
def _normalize_seed(value: float | int | None) -> int | None:
if value is None:
return None
seed = int(value)
return None if seed == 0 else seed
def _format_completion_status(
chunk_count: int,
sample_rate: int,
full_audio: np.ndarray,
started_at: float,
first_audio_time: float | None,
) -> str:
elapsed = time.monotonic() - started_at
audio_seconds = float(full_audio.size) / float(sample_rate) if full_audio.size > 0 else 0.0
rtf = (elapsed / audio_seconds) if audio_seconds > 0 else float("inf")
parts = [
"Done",
]
return " | ".join(parts)
@functools.lru_cache(maxsize=1)
def _load_backend(
    model_path: str,
    tokenizer_path: str,
    codec_model_path: str,
    device_str: str,
    attn_impl: str,
):
    """Load the TTS model, tokenizer, processor, and codec; cached on all args.

    With maxsize=1, repeated calls with the same backend settings reuse the
    already-loaded models. Returns (model, tokenizer, processor, codec, device).
    """
    # ZeroGPU: do not call torch.cuda.is_available() here; it may trigger low-level CUDA init.
    device = torch.device(device_str)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    processor = MossTTSRealtimeProcessor(tokenizer)
    # ZeroGPU: avoid torch.cuda.is_bf16_supported() before CUDA is fully managed.
    dtype = torch.bfloat16
    if attn_impl and attn_impl.lower() not in {"none", ""}:
        model = MossTTSRealtime.from_pretrained(model_path, attn_implementation=attn_impl, torch_dtype=dtype).to(device)
        if (
            attn_impl.lower() == "flash_attention_2"
            and hasattr(model, "language_model")
            and hasattr(model.language_model, "config")
        ):
            # Propagate the attention choice down to the inner language-model config.
            model.language_model.config.attn_implementation = "flash_attention_2"
    else:
        model = MossTTSRealtime.from_pretrained(model_path, torch_dtype=dtype).to(device)
    model.eval()
    codec = _load_codec(device, codec_model_path)
    return model, tokenizer, processor, codec, device
def _resolve_audio_path(audio_path: str | None, use_default: bool, default_path: str | Path) -> Path | None:
if audio_path:
return Path(audio_path).expanduser()
if use_default:
return Path(default_path).expanduser()
return None
class StreamingTTSDemo:
    """Orchestrates one streaming TTS turn: validation, encoding, and generation.

    Keeps a small LRU cache of codec-encoded audio tokens keyed by
    (resolved path, mtime, chunk_duration) so repeated prompts skip re-encoding.
    """

    def __init__(self, audio_token_cache_size: int = 8):
        self._audio_token_cache_size = max(1, int(audio_token_cache_size))
        self._audio_token_cache: OrderedDict[tuple[str, int, float], np.ndarray] = OrderedDict()

    def get_or_load_backend(self, backend: BackendPaths):
        """Return the (cached) model/tokenizer/processor/codec/device tuple."""
        return _load_backend(
            backend.model_path,
            backend.tokenizer_path,
            backend.codec_model_path,
            backend.device_str,
            backend.attn_impl,
        )

    def _validate_request(self, request: StreamingRequest) -> tuple[Path | None, Path | None]:
        """Validate text/streaming fields and resolve audio paths.

        Raises ValueError for bad parameters, FileNotFoundError for missing
        audio. Returns (prompt_path, user_path), either of which may be None.
        """
        if not request.user_text.strip():
            raise ValueError("user_text is required.")
        if not request.assistant_text.strip():
            raise ValueError("assistant_text is required.")
        if request.streaming.text_chunk_tokens <= 0:
            raise ValueError("text_chunk_tokens must be greater than 0.")
        if request.streaming.decode_chunk_frames <= 0:
            raise ValueError("decode_chunk_frames must be greater than 0.")
        if request.streaming.chunk_duration <= 0:
            raise ValueError("chunk_duration must be greater than 0.")
        prompt_path = _resolve_audio_path(request.prompt_audio, request.use_default_prompt, PROMPT_WAV)
        user_path = _resolve_audio_path(request.user_audio, request.use_default_user, USER_WAV)
        if prompt_path is not None and not prompt_path.exists():
            raise FileNotFoundError(f"Prompt wav not found: {prompt_path}")
        if user_path is not None and not user_path.exists():
            raise FileNotFoundError(f"User wav not found: {user_path}")
        return prompt_path, user_path

    def _encode_audio_tokens(
        self,
        path: Path,
        codec,
        device: torch.device,
        chunk_duration: float,
    ) -> np.ndarray:
        """Encode an audio file to codec tokens, with mtime-aware LRU caching."""
        resolved_path = path.expanduser().resolve()
        # mtime in the key invalidates the cache when the file changes on disk.
        cache_key = (str(resolved_path), int(resolved_path.stat().st_mtime_ns), float(chunk_duration))
        cached_tokens = self._audio_token_cache.get(cache_key)
        if cached_tokens is not None:
            self._audio_token_cache.move_to_end(cache_key)
            return cached_tokens
        with torch.inference_mode():
            audio_tensor = _load_audio(resolved_path)
            waveform = audio_tensor.to(device)
            # Codec expects a batch dim: [C, T] -> [1, C, T].
            if waveform.dim() == 2:
                waveform = waveform.unsqueeze(0)
            encode_result = codec.encode(waveform, chunk_duration=chunk_duration)
            tokens = _extract_codes(encode_result)
            if isinstance(tokens, torch.Tensor):
                tokens = tokens.detach().cpu().numpy()
            else:
                tokens = np.asarray(tokens)
        self._audio_token_cache[cache_key] = tokens
        self._audio_token_cache.move_to_end(cache_key)
        # Evict least-recently-used entries past the cache size.
        while len(self._audio_token_cache) > self._audio_token_cache_size:
            self._audio_token_cache.popitem(last=False)
        return tokens

    @staticmethod
    def _build_text_only_turn_input(
        processor: MossTTSRealtimeProcessor,
        user_text: str,
        prompt_tokens: np.ndarray | None,
    ) -> np.ndarray:
        """Build multi-channel input ids for a turn with no user audio.

        Text token ids go in channel 0; audio channels are padded.
        """
        system_prompt = processor.make_ensemble(prompt_tokens)
        user_prompt_text = "<|im_end|>\n<|im_start|>user\n" + user_text + "<|im_end|>\n<|im_start|>assistant\n"
        user_prompt_tokens = processor.tokenizer(user_prompt_text)["input_ids"]
        user_prompt = np.full(
            shape=(len(user_prompt_tokens), processor.channels + 1),
            fill_value=processor.audio_channel_pad,
            dtype=np.int64,
        )
        user_prompt[:, 0] = np.asarray(user_prompt_tokens, dtype=np.int64)
        return np.concatenate([system_prompt, user_prompt], axis=0)

    def _prepare_session_turn(
        self,
        session: MossTTSRealtimeStreamingSession,
        processor: MossTTSRealtimeProcessor,
        user_text: str,
        prompt_tokens: np.ndarray | None,
        user_tokens: np.ndarray | None,
    ) -> str | None:
        """Reset the session for this turn; returns an info message or None."""
        if user_tokens is None:
            turn_input_ids = self._build_text_only_turn_input(processor, user_text, prompt_tokens)
            session.reset_turn(input_ids=turn_input_ids, include_system_prompt=True, reset_cache=True)
            return "No user audio provided, running text-only turn."
        session.reset_turn(
            user_text=user_text,
            user_audio_tokens=user_tokens,
            include_system_prompt=True,
            reset_cache=True,
        )
        return None

    def run_stream(self, request: StreamingRequest) -> Iterator[StreamEvent]:
        """Run a full streaming turn, yielding StreamEvents as audio is decoded.

        Raises RuntimeError when tokenization or decoding produces nothing.
        """
        prompt_path, user_path = self._validate_request(request)
        model, tokenizer, processor, codec, device = self.get_or_load_backend(request.backend)
        _apply_seed(request.generation.seed)
        prompt_tokens = (
            self._encode_audio_tokens(
                prompt_path,
                codec,
                device,
                chunk_duration=request.streaming.chunk_duration,
            )
            if prompt_path is not None
            else None
        )
        user_tokens = (
            self._encode_audio_tokens(
                user_path,
                codec,
                device,
                chunk_duration=request.streaming.chunk_duration,
            )
            if user_path is not None
            else None
        )
        session, inferencer = _build_streaming_session(
            model,
            tokenizer,
            processor,
            codec,
            max_length=request.generation.max_length,
            chunk_duration=request.streaming.chunk_duration,
            temperature=request.generation.temperature,
            top_p=request.generation.top_p,
            top_k=request.generation.top_k,
            do_sample=request.generation.do_sample,
            repetition_penalty=request.generation.repetition_penalty,
            repetition_window=request.generation.repetition_window,
        )
        if prompt_tokens is not None:
            session.set_voice_prompt_tokens(prompt_tokens)
        else:
            session.clear_voice_prompt()
        turn_message = self._prepare_session_turn(
            session,
            processor,
            request.user_text,
            prompt_tokens,
            user_tokens,
        )
        if turn_message:
            yield StreamEvent(message=turn_message)
        frame_decoder = _build_frame_decoder(
            codec,
            inferencer,
            device,
            chunk_frames=request.streaming.decode_chunk_frames,
            overlap_frames=request.streaming.decode_overlap_frames,
        )
        text_tokens = tokenizer.encode(request.assistant_text, add_special_tokens=False)
        if not text_tokens:
            raise RuntimeError("Assistant text tokenization returned no tokens.")
        token_stream = TokenChunkStream(text_tokens, request.streaming.text_chunk_tokens)
        audio_emitter = StreamAudioEmitter(SAMPLE_RATE, request.streaming.prebuffer_seconds)
        with codec.streaming(batch_size=1):
            # Feed assistant text in chunks, decoding audio frames as they appear.
            for token_chunk in token_stream:
                audio_emitter.wait_for_capacity(request.streaming.buffer_threshold_seconds)
                audio_frames = session.push_text_tokens(token_chunk)
                yield from audio_emitter.emit_many(frame_decoder.decode_frames(audio_frames), "Streaming")
                if request.streaming.input_delay > 0:
                    time.sleep(request.streaming.input_delay)
            # Signal end of text, then drain remaining audio one step at a time.
            final_frames = session.end_text()
            yield from audio_emitter.emit_many(frame_decoder.decode_frames(final_frames), "Finalizing")
            while True:
                drain_frames = session.drain(max_steps=1)
                if not drain_frames:
                    break
                yield from audio_emitter.emit_many(frame_decoder.decode_frames(drain_frames), "Finalizing")
                if session.inferencer.is_finished:
                    break
            yield from audio_emitter.emit_many(frame_decoder.flush(), "Final")
        # Release anything still held by an unfinished prebuffer.
        yield from audio_emitter.flush("Final")
        if not audio_emitter.has_audio:
            raise RuntimeError("No audio waveform chunks decoded from streaming inference.")
        yield StreamEvent(message="Streaming complete.")
class WarmupManager:
    """Runs a background warmup pass and exposes thread-safe progress snapshots.

    NOTE(review): __init__ marks the state "ready" up front (ZeroGPU comment
    below), yet start()/_run still perform a full warmup — presumably start()
    is simply never invoked in the ZeroGPU deployment; confirm against the
    UI wiring outside this view.
    """

    def __init__(self, tts_demo: "StreamingTTSDemo", backend: BackendPaths):
        self.tts_demo = tts_demo
        self.backend = backend
        self._lock = threading.Lock()
        self._thread: threading.Thread | None = None
        self._started = False
        # ZeroGPU: startup warmup is disabled because it initializes CUDA outside @spaces.GPU.
        self._state = "ready"
        self._progress = 1.0
        self._message = "Ready."
        self._detail = "Startup warmup disabled for ZeroGPU; the first generation will load the model."
        self._error: str | None = None

    def start(self) -> None:
        """Spawn the warmup thread once; subsequent calls are no-ops."""
        with self._lock:
            if self._started:
                return
            self._started = True
            self._thread = threading.Thread(target=self._run, name="tts-startup-warmup", daemon=True)
            self._thread.start()

    def snapshot(self) -> WarmupSnapshot:
        """Return a consistent copy of the current warmup state."""
        with self._lock:
            return WarmupSnapshot(
                state=self._state,
                progress=self._progress,
                message=self._message,
                detail=self._detail,
                error=self._error,
            )

    def _set_state(
        self,
        *,
        state: str | None = None,
        progress: float | None = None,
        message: str | None = None,
        detail: str | None = None,
        error: str | None = None,
    ) -> None:
        """Update any subset of the state fields under the lock.

        ``error`` is always overwritten (so passing None clears it); the other
        fields keep their previous value when None. Progress is clamped to [0, 1].
        """
        with self._lock:
            if state is not None:
                self._state = state
            if progress is not None:
                self._progress = max(0.0, min(1.0, float(progress)))
            if message is not None:
                self._message = message
            if detail is not None:
                self._detail = detail
            self._error = error

    @staticmethod
    def _consume_audio(chunks: Iterator[np.ndarray]) -> None:
        """Drain an audio generator, discarding the chunks (warmup only)."""
        for _chunk in chunks:
            pass

    @staticmethod
    def _ensure_warmup_text(tokenizer, minimum_tokens: int) -> tuple[str, list[int]]:
        """Repeat the base warmup sentence until it tokenizes to >= minimum_tokens."""
        text = WARMUP_BASE_ASSISTANT_TEXT
        tokens = tokenizer.encode(text, add_special_tokens=False)
        while len(tokens) < minimum_tokens:
            text = f"{text} {WARMUP_BASE_ASSISTANT_TEXT}"
            tokens = tokenizer.encode(text, add_special_tokens=False)
        return text, tokens

    @staticmethod
    def _warmup_step_detail(step_idx: int, total_steps: int) -> str:
        """Human-readable description for a given warmup token step."""
        if step_idx == 1:
            return "First incremental step is compiling the cold streaming path."
        if step_idx == 2:
            return "Second incremental step is warming the next steady-state path."
        if step_idx == DEFAULT_REPETITION_WINDOW:
            return "Warming the first full repetition-window step."
        if step_idx == WARMUP_STEP_TOKENS:
            return "Confirming the post-window steady-state step."
        return f"Warming token step {step_idx}/{total_steps}."

    def _run(self) -> None:
        """Thread body: exercise load, prefill, per-token, and finalize paths."""
        try:
            self._set_state(
                state="running",
                progress=0.02,
                message="Starting startup warmup.",
                detail="Preparing backend state for the first real request.",
                error=None,
            )
            self._set_state(
                progress=0.08,
                message="Loading backend.",
                detail="Model, tokenizer, codec, and CUDA runtime are warming up.",
                error=None,
            )
            model, tokenizer, processor, codec, device = self.tts_demo.get_or_load_backend(self.backend)
            self._set_state(
                progress=0.32,
                message="Preparing streaming session.",
                detail="Building a text-only warmup turn and its decoder.",
                error=None,
            )
            session, inferencer = _build_streaming_session(
                model,
                tokenizer,
                processor,
                codec,
                max_length=256,
                chunk_duration=0.24,
                temperature=0.8,
                top_p=0.6,
                top_k=30,
                do_sample=True,
                repetition_penalty=1.1,
                repetition_window=DEFAULT_REPETITION_WINDOW,
            )
            session.clear_voice_prompt()
            session.reset_turn(
                input_ids=self.tts_demo._build_text_only_turn_input(processor, WARMUP_USER_TEXT, None),
                include_system_prompt=True,
                reset_cache=True,
            )
            frame_decoder = _build_frame_decoder(
                codec,
                inferencer,
                device,
                chunk_frames=WARMUP_STEP_TOKENS,
                overlap_frames=0,
            )
            # Enough text to cover prefill plus WARMUP_STEP_TOKENS single-token steps.
            _, warmup_tokens = self._ensure_warmup_text(
                tokenizer,
                processor.delay_tokens_len + WARMUP_STEP_TOKENS,
            )
            with codec.streaming(batch_size=1):
                self._set_state(
                    progress=0.45,
                    message="Running prefill.",
                    detail="Building the first KV cache and warming the backbone path.",
                    error=None,
                )
                prefill_frames = session.push_text_tokens(warmup_tokens[: processor.delay_tokens_len])
                self._consume_audio(frame_decoder.decode_frames(prefill_frames))
                step_tokens = warmup_tokens[
                    processor.delay_tokens_len : processor.delay_tokens_len + WARMUP_STEP_TOKENS
                ]
                total_steps = max(1, len(step_tokens))
                # Push tokens one at a time to hit every incremental compile path.
                for idx, token in enumerate(step_tokens, start=1):
                    self._set_state(
                        progress=0.55 + 0.25 * (idx - 1) / total_steps,
                        message="Compiling first streaming steps.",
                        detail=self._warmup_step_detail(idx, total_steps),
                        error=None,
                    )
                    step_frames = session.push_text_tokens([token])
                    self._consume_audio(frame_decoder.decode_frames(step_frames))
                self._set_state(
                    progress=0.86,
                    message="Warming finalization path.",
                    detail="Priming end-text, drain, and decoder flush before user traffic.",
                    error=None,
                )
                final_frames = session.end_text()
                self._consume_audio(frame_decoder.decode_frames(final_frames))
                drain_frames = session.drain(max_steps=1)
                self._consume_audio(frame_decoder.decode_frames(drain_frames))
                self._consume_audio(frame_decoder.flush())
            self._set_state(
                state="ready",
                progress=1.0,
                message="Warmup complete.",
                detail="The first real request should avoid the cold-start stall.",
                error=None,
            )
        except Exception as exc:
            self._set_state(
                state="failed",
                progress=1.0,
                message="Warmup failed.",
                detail="The app did not finish startup warmup.",
                error=str(exc),
            )
            print(f"[MossTTSRealtime][warmup-error] {exc}", file=sys.stderr, flush=True)
def _warmup_button_update(snapshot: WarmupSnapshot):
    """Map warmup state to the Generate button's label and interactivity."""
    if snapshot.ready:
        return gr.update(value="Generate", interactive=True)
    label = "Warmup Failed" if snapshot.failed else "Warming Up..."
    return gr.update(value=label, interactive=False)
def _warmup_gate_message(snapshot: WarmupSnapshot) -> str:
progress_pct = int(round(max(0.0, min(1.0, snapshot.progress)) * 100.0))
if snapshot.failed:
return f"Warmup failed: {snapshot.error or snapshot.message}"
return f"Warmup in progress ({progress_pct}%): {snapshot.message}"
def _status_from_snapshot(snapshot: WarmupSnapshot) -> str:
    # Ready collapses to a short fixed label; otherwise show gated progress text.
    return "Ready." if snapshot.ready else _warmup_gate_message(snapshot)


def _warmup_status_update(snapshot: WarmupSnapshot):
    # Gradio update payload for the warmup status textbox.
    return gr.update(value=_status_from_snapshot(snapshot))


def _warmup_timer_update(snapshot: WarmupSnapshot):
    # Keep the polling timer active only while warmup is still in flight.
    return gr.update(active=not (snapshot.ready or snapshot.failed))
def _encode_chunk(sr: int, chunk: np.ndarray, idx: int) -> str:
if chunk.dtype != np.float32:
chunk = chunk.astype(np.float32)
if chunk.ndim != 1:
chunk = chunk.reshape(-1)
payload = {
"sr": int(sr),
"idx": int(idx),
"data": base64.b64encode(chunk.tobytes()).decode("ascii"),
}
return json.dumps(payload)
def _build_request(
    args: argparse.Namespace,
    *,
    user_text: str | None,
    assistant_text: str | None,
    prompt_audio: str | None,
    user_audio: str | None,
    use_default_prompt: bool,
    use_default_user: bool,
    temperature: float,
    top_p: float,
    top_k: int,
    repetition_penalty: float,
    repetition_window: int,
    do_sample: bool,
    max_length: int,
    seed: float | int | None,
    text_chunk_tokens: int,
    input_delay: float,
    decode_chunk_frames: int,
    decode_overlap_frames: int,
    chunk_duration: float,
    prebuffer_seconds: float,
) -> StreamingRequest:
    """Assemble a StreamingRequest from raw UI values plus CLI backend paths."""
    generation = GenerationConfig(
        temperature=float(temperature),
        top_p=float(top_p),
        top_k=int(top_k),
        repetition_penalty=float(repetition_penalty),
        repetition_window=int(repetition_window),
        do_sample=bool(do_sample),
        max_length=int(max_length),
        seed=_normalize_seed(seed),
    )
    streaming = StreamingConfig(
        text_chunk_tokens=int(text_chunk_tokens),
        input_delay=float(input_delay),
        decode_chunk_frames=int(decode_chunk_frames),
        decode_overlap_frames=int(decode_overlap_frames),
        chunk_duration=float(chunk_duration),
        prebuffer_seconds=float(prebuffer_seconds),
    )
    backend = BackendPaths(
        model_path=args.model_path,
        tokenizer_path=args.tokenizer_path,
        codec_model_path=args.codec_model_path,
        device_str=args.device,
        attn_impl=args.attn_implementation,
    )
    return StreamingRequest(
        user_text=str(user_text or "Hello!"),
        assistant_text=str(assistant_text or ""),
        prompt_audio=prompt_audio,
        user_audio=user_audio,
        use_default_prompt=use_default_prompt,
        use_default_user=use_default_user,
        generation=generation,
        streaming=streaming,
        backend=backend,
    )
# CSS that keeps the "pcm_stream" textbox present in the DOM but invisible and
# non-interactive: the backend writes JSON-encoded PCM payloads into it and the
# JS stream player reads them out, so the element is a hidden transport channel.
STREAM_PLAYER_HTML = """
<style>
#pcm_stream {
position: absolute !important;
left: -9999px !important;
width: 1px !important;
height: 1px !important;
opacity: 0 !important;
pointer-events: none !important;
}
#pcm_stream textarea, #pcm_stream input {
width: 1px !important;
height: 1px !important;
opacity: 0 !important;
}
</style>
"""
STREAM_PLAYER_JS = r"""
const elemId = "pcm_stream";
if (window.__pcm_streaming_inited__) {
return;
}
window.__pcm_streaming_inited__ = true;
let audioCtx = null;
let nextTime = 0;
let lastIdx = -1;
let lastValue = "";
let boundField = null;
let usingSetterHook = false;
const FADE_MS = 6;
const MIN_BUFFER_SEC = 0.25;
function initAudio(sr) {
if (audioCtx && audioCtx.sampleRate !== sr) {
audioCtx.close();
audioCtx = null;
}
if (!audioCtx) {
audioCtx = new (window.AudioContext || window.webkitAudioContext)({ sampleRate: sr });
nextTime = audioCtx.currentTime;
}
if (audioCtx.state === "suspended") {
audioCtx.resume();
}
}
function decodeBase64ToFloat32(base64) {
const binary = atob(base64);
const len = binary.length;
const bytes = new Uint8Array(len);
for (let i = 0; i < len; i++) {
bytes[i] = binary.charCodeAt(i);
}
return new Float32Array(bytes.buffer);
}
function playChunk(samples, sr, idx) {
initAudio(sr);
const buffer = audioCtx.createBuffer(1, samples.length, sr);
buffer.copyToChannel(samples, 0);
const source = audioCtx.createBufferSource();
source.buffer = buffer;
const gain = audioCtx.createGain();
source.connect(gain);
gain.connect(audioCtx.destination);
const now = audioCtx.currentTime;
if (nextTime < now + MIN_BUFFER_SEC) {
nextTime = now + MIN_BUFFER_SEC;
}
const startTime = Math.max(now, nextTime);
const endTime = startTime + buffer.duration;
const fade = Math.min(FADE_MS / 1000.0, buffer.duration / 4);
gain.gain.setValueAtTime(0.0, startTime);
gain.gain.linearRampToValueAtTime(1.0, startTime + fade);
gain.gain.setValueAtTime(1.0, Math.max(startTime + fade, endTime - fade));
gain.gain.linearRampToValueAtTime(0.0, endTime);
source.start(startTime);
nextTime = endTime;
}
function handlePayload(text) {
if (!text) return;
let payload;
try {
payload = JSON.parse(text);
} catch (e) {
return;
}
if (Array.isArray(payload)) {
for (const item of payload) {
handlePayloadObject(item);
}
return;
}
handlePayloadObject(payload);
}
function handlePayloadObject(payload) {
if (!payload) return;
if (payload.reset) {
lastIdx = -1;
lastValue = "";
if (audioCtx) {
audioCtx.close();
audioCtx = null;
}
return;
}
const idx = payload.idx ?? 0;
if (idx <= lastIdx) return;
lastIdx = idx;
const sr = payload.sr || 24000;
const samples = decodeBase64ToFloat32(payload.data);
playChunk(samples, sr, idx);
}
function hookField(field) {
if (!field || field === boundField) return;
boundField = field;
const proto = field.tagName === "TEXTAREA" ? HTMLTextAreaElement.prototype : HTMLInputElement.prototype;
const desc = Object.getOwnPropertyDescriptor(proto, "value");
if (!desc || !desc.get || !desc.set) {
usingSetterHook = false;
return;
}
usingSetterHook = true;
const nativeGet = desc.get;
const nativeSet = desc.set;
Object.defineProperty(field, "value", {
configurable: true,
get() {
return nativeGet.call(field);
},
set(v) {
nativeSet.call(field, v);
if (v && v !== lastValue) {
lastValue = v;
handlePayload(v);
}
},
});
const initial = field.value;
if (initial && initial !== lastValue) {
lastValue = initial;
handlePayload(initial);
}
}
function pollField() {
const field = document.querySelector(`#${elemId} textarea, #${elemId} input`);
if (!field) {
boundField = null;
usingSetterHook = false;
setTimeout(pollField, 300);
return;
}
if (field !== boundField) {
hookField(field);
}
setTimeout(pollField, 300);
}
function pollValue() {
if (usingSetterHook) {
setTimeout(pollValue, 500);
return;
}
const field = document.querySelector(`#${elemId} textarea, #${elemId} input`);
if (!field) {
setTimeout(pollValue, 300);
return;
}
const value = field.value;
if (value && value !== lastValue) {
lastValue = value;
handlePayload(value);
}
setTimeout(pollValue, 40);
}
function tryUnlockAudio() {
if (!audioCtx) {
audioCtx = new (window.AudioContext || window.webkitAudioContext)();
}
if (audioCtx.state === "suspended") {
audioCtx.resume();
}
}
document.addEventListener("click", (event) => {
const btn = event.target.closest("#tts_generate");
if (btn) {
tryUnlockAudio();
}
});
pollField();
pollValue();
"""
def _build_demo(
    args: argparse.Namespace,
    tts_demo: StreamingTTSDemo,
    warmup_manager: WarmupManager,
):
    """Assemble and return the Gradio Blocks UI for the streaming TTS demo.

    Args:
        args: Parsed CLI arguments; forwarded into ``_build_request`` so the
            generate handler knows the configured backend paths.
        tts_demo: Streaming backend whose ``run_stream`` yields audio events.
        warmup_manager: Provides ``snapshot()`` used to gate the Generate
            button and drive the warmup status poller.

    Returns:
        The constructed ``gr.Blocks`` demo (not yet queued or launched).
    """
    initial_warmup_snapshot = warmup_manager.snapshot()
    with gr.Blocks(title="MossTTSRealtime") as demo:
        gr.Markdown("MossTTSRealtime demo")
        # Inject the hidden-textbox CSS and the browser-side PCM player script.
        gr.HTML(STREAM_PLAYER_HTML, js_on_load=STREAM_PLAYER_JS)
        with gr.Row():
            with gr.Column():
                assistant_text = gr.Textbox(label="Assistant Text", lines=6)
                prompt_audio = gr.Audio(label="Prompt WAV (optional)", type="filepath")
                with gr.Accordion("User Input Options", open=False):
                    user_text = gr.Textbox(label="User Text(optional)", lines=2)
                    user_audio = gr.Audio(label="User WAV (optional)", type="filepath")
                    use_default_prompt = gr.Checkbox(label="Use Default Prompt WAV (fallback)", value=False)
                    use_default_user = gr.Checkbox(label="Use Default User WAV (fallback)", value=False)
                with gr.Accordion("Generation Options", open=False):
                    temperature = gr.Slider(0.1, 1.5, value=0.8, step=0.05, label="Temperature")
                    top_p = gr.Slider(0.1, 1.0, value=0.6, step=0.05, label="Top P")
                    top_k = gr.Slider(1, 100, value=30, step=1, label="Top K")
                    repetition_penalty = gr.Slider(1.0, 2.0, value=1.1, step=0.05, label="Repetition Penalty")
                    repetition_window = gr.Slider(
                        1, 200, value=DEFAULT_REPETITION_WINDOW, step=1, label="Repetition Window"
                    )
                    do_sample = gr.Checkbox(label="Do Sample", value=True)
                    max_length = gr.Slider(100, 10000, value=2000, step=10, label="Max Length")
                    seed = gr.Number(value=0, precision=0, label="Seed (0 for random)")
                with gr.Accordion("Streaming Options", open=False):
                    stream_text_chunk_tokens = gr.Slider(1, 64, value=12, step=1, label="Text Chunk Tokens")
                    stream_input_delay = gr.Slider(0.0, 0.5, value=0.0, step=0.05, label="Input Delay (s)")
                    stream_decode_chunk_frames = gr.Slider(1, 20, value=6, step=1, label="Decode Chunk Frames")
                    stream_decode_overlap_frames = gr.Slider(0, 10, value=0, step=1, label="Decode Overlap Frames")
                    chunk_duration = gr.Slider(0.08, 4.0, value=0.96, step=0.08, label="Codec Chunk Duration (s)")
                    stream_prebuffer_seconds = gr.Slider(0.0, 20.0, value=0.0, step=0.05, label="Initial Buffer (s)")
                # Button starts disabled (and relabeled) until warmup reports ready.
                run_btn = gr.Button(
                    "Generate" if initial_warmup_snapshot.ready else "Warming Up...",
                    elem_id="tts_generate",
                    interactive=initial_warmup_snapshot.ready,
                )
            with gr.Column():
                # Hidden via STREAM_PLAYER_HTML CSS; the JS player reads its value.
                stream_data = gr.Textbox(label="PCM Stream (JSON)", elem_id="pcm_stream", interactive=False, lines=6)
                output_audio = gr.Audio(label="Final Audio", type="numpy")
                initial_status = _status_from_snapshot(initial_warmup_snapshot)
                status = gr.Textbox(label="Status", lines=3, value=initial_status)
        # Timer only ticks while warmup is still in progress.
        warmup_timer = gr.Timer(value=WARMUP_POLL_INTERVAL_SECONDS, active=not initial_warmup_snapshot.ready)

        def _poll_warmup_state():
            # Refresh button label/interactivity, status text, and timer
            # activity from the latest warmup snapshot.
            snapshot = warmup_manager.snapshot()
            return (
                _warmup_button_update(snapshot),
                _warmup_status_update(snapshot),
                _warmup_timer_update(snapshot),
            )

        @spaces.GPU
        def _on_generate(
            user_text_value,
            assistant_text_value,
            prompt_audio_value,
            user_audio_value,
            use_default_prompt_value,
            use_default_user_value,
            temperature_value,
            top_p_value,
            top_k_value,
            repetition_penalty_value,
            repetition_window_value,
            do_sample_value,
            max_length_value,
            seed_value,
            stream_text_chunk_tokens_value,
            stream_input_delay_value,
            stream_decode_chunk_frames_value,
            stream_decode_overlap_frames_value,
            chunk_duration_value,
            stream_prebuffer_seconds_value,
        ):
            """Run one generation request on the GPU worker.

            Returns a 3-tuple matching ``outputs=[stream_data, output_audio,
            status]``: the stream textbox value (always "" here, clearing it),
            the final audio as ``(sample_rate, np.ndarray)`` or ``None``, and
            a human-readable status string.
            """
            try:
                started_at = time.monotonic()
                # These are bound before any call that can raise, so the
                # except block below can safely reference them.
                full_chunks: list[np.ndarray] = []
                first_audio_time: float | None = None
                sample_rate = SAMPLE_RATE
                rtf_logger: StreamRTFLogger | None = None
                request = _build_request(
                    args,
                    user_text=user_text_value,
                    assistant_text=assistant_text_value,
                    prompt_audio=prompt_audio_value,
                    user_audio=user_audio_value,
                    use_default_prompt=bool(use_default_prompt_value),
                    use_default_user=bool(use_default_user_value),
                    temperature=float(temperature_value),
                    top_p=float(top_p_value),
                    top_k=int(top_k_value),
                    repetition_penalty=float(repetition_penalty_value),
                    repetition_window=int(repetition_window_value),
                    do_sample=bool(do_sample_value),
                    max_length=int(max_length_value),
                    seed=seed_value,
                    text_chunk_tokens=int(stream_text_chunk_tokens_value),
                    input_delay=float(stream_input_delay_value),
                    decode_chunk_frames=int(stream_decode_chunk_frames_value),
                    decode_overlap_frames=int(stream_decode_overlap_frames_value),
                    chunk_duration=float(chunk_duration_value),
                    prebuffer_seconds=float(stream_prebuffer_seconds_value),
                )
                rtf_logger = StreamRTFLogger.create(request, started_at)
                # Drain the streaming backend, collecting audio chunks and
                # recording time-to-first-audio for the RTF log.
                for event in tts_demo.run_stream(request):
                    if event.audio is None:
                        continue
                    sr, chunk = event.audio
                    chunk = np.asarray(chunk).reshape(-1)
                    if chunk.size == 0:
                        continue
                    full_chunks.append(chunk)
                    sample_rate = sr
                    if first_audio_time is None:
                        first_audio_time = time.monotonic()
                    if rtf_logger is not None:
                        rtf_logger.log_chunk(
                            event_message=event.message,
                            sample_rate=sr,
                            chunk=chunk,
                            first_audio_time=first_audio_time,
                        )
                if full_chunks:
                    full_audio = np.concatenate(full_chunks)
                    if rtf_logger is not None:
                        rtf_logger.log_completion(first_audio_time=first_audio_time)
                    done_msg = _format_completion_status(
                        len(full_chunks),
                        sample_rate,
                        full_audio,
                        started_at,
                        first_audio_time,
                    )
                    return "", (sample_rate, full_audio), done_msg
                if rtf_logger is not None:
                    rtf_logger.log_no_audio()
                return "", None, "Done | no audio chunks emitted"
            except Exception as exc:
                import traceback

                traceback.print_exc()
                if rtf_logger is not None:
                    rtf_logger.log_error(exc, first_audio_time=first_audio_time)
                # Surface the failure in the status box instead of crashing the UI.
                return "", None, f"Error: {exc}"

        run_btn.click(
            _on_generate,
            inputs=[
                user_text,
                assistant_text,
                prompt_audio,
                user_audio,
                use_default_prompt,
                use_default_user,
                temperature,
                top_p,
                top_k,
                repetition_penalty,
                repetition_window,
                do_sample,
                max_length,
                seed,
                stream_text_chunk_tokens,
                stream_input_delay,
                stream_decode_chunk_frames,
                stream_decode_overlap_frames,
                chunk_duration,
                stream_prebuffer_seconds,
            ],
            outputs=[stream_data, output_audio, status],
        )
        # Sync warmup state once on page load and then on every timer tick;
        # queue=False keeps these polls off the generation queue.
        demo.load(
            _poll_warmup_state,
            outputs=[run_btn, status, warmup_timer],
            queue=False,
            show_progress="hidden",
        )
        warmup_timer.tick(
            _poll_warmup_state,
            outputs=[run_btn, status, warmup_timer],
            queue=False,
            show_progress="hidden",
        )
    return demo
def main():
    """CLI entry point: parse arguments, build the demo UI, and launch Gradio."""
    cli = argparse.ArgumentParser(description="MossTTSRealtime streaming TTS Gradio demo")
    cli.add_argument("--model_path", type=str, default=MODEL_PATH)
    cli.add_argument("--tokenizer_path", type=str, default=TOKENIZER_PATH)
    cli.add_argument("--codec_model_path", type=str, default=CODEC_MODEL_PATH)
    cli.add_argument("--device", type=str, default="cuda:0")
    cli.add_argument(
        "--attn_implementation",
        type=str,
        default="sdpa",
        choices=["sdpa", "flash_attention_2", "eager", "none"],
    )
    cli.add_argument("--host", type=str, default="0.0.0.0")
    cli.add_argument("--port", type=int, default=7860)
    cli.add_argument("--share", action="store_true")
    parsed = cli.parse_args()

    backend = BackendPaths(
        model_path=parsed.model_path,
        tokenizer_path=parsed.tokenizer_path,
        codec_model_path=parsed.codec_model_path,
        device_str=parsed.device,
        attn_impl=parsed.attn_implementation,
    )
    streaming_demo = StreamingTTSDemo()
    warmup = WarmupManager(streaming_demo, backend)
    # ZeroGPU: startup warmup is deliberately NOT started here — it would
    # initialize CUDA in a background thread outside @spaces.GPU.

    app = _build_demo(parsed, streaming_demo, warmup)
    app.queue(max_size=10, default_concurrency_limit=1).launch(
        server_name=parsed.host,
        server_port=parsed.port,
        share=parsed.share,
        ssr_mode=False,
    )
# Script entry point: only run the demo when executed directly, not on import.
if __name__ == "__main__":
    main()