"""
BeigeTTS - Streaming Inference

Real-time streaming text-to-speech with chunk-based generation.

Research release derived from BlandAI's Khaki TTS system.
"""

import argparse
import queue
import threading
import time
from typing import Generator, Optional, Tuple

import numpy as np
import pyaudio
import soundfile as sf
import torch
from neucodec import NeuCodec
from transformers import AutoModelForCausalLM, AutoTokenizer

class StreamingConfig:
    """Configuration for streaming TTS."""

    # Special tokens delimiting the speech span in the LM output.
    AUDIO_START_TOKEN = 262145
    AUDIO_END_TOKEN = 262146

    # NeuCodec codes are embedded in the LM vocabulary at this offset.
    NEUCODEC_BASE_OFFSET = 262154
    NEUCODEC_VOCABULARY_SIZE = 65536
    AUDIO_TOKEN_MIN = NEUCODEC_BASE_OFFSET
    AUDIO_TOKEN_MAX = NEUCODEC_BASE_OFFSET + NEUCODEC_VOCABULARY_SIZE  # exclusive upper bound

    # Streaming parameters.
    CHUNK_SIZE = 50      # audio tokens decoded per streamed chunk
    BUFFER_SIZE = 3      # max chunks buffered between generator and consumer
    SAMPLE_RATE = 24000  # output sample rate in Hz

    # Generation defaults.
    DEFAULT_TEMPERATURE = 0.1
    DEFAULT_TOP_P = 0.97
    MAX_TOTAL_TOKENS = 1000
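
# Worked example of the token mapping above: the LM emits audio tokens in
# [AUDIO_TOKEN_MIN, AUDIO_TOKEN_MAX); subtracting NEUCODEC_BASE_OFFSET recovers
# the raw NeuCodec code, e.g. token 262154 -> code 0, token 263388 -> code 1234.
# Assuming NeuCodec's roughly 50-codes-per-second frame rate, CHUNK_SIZE = 50
# corresponds to on the order of one second of audio per streamed chunk.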


class StreamingBeigeTTS:
    """Streaming BeigeTTS engine with chunk-based generation.

    Note: The production Khaki system achieves <50ms latency with
    advanced buffering and predictive generation.
    """

    def __init__(self, model_path: str = "BlandAI/BeigeTTS"):
        """Initialize the streaming TTS engine."""
        self.config = StreamingConfig()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        print("Initializing BeigeTTS streaming engine (research release)...")
        print("Note: Khaki production system supports <50ms latency and 57 languages")

        self._load_models(model_path)

        # Inter-thread plumbing: decoded chunks flow through audio_queue; the
        # events signal normal completion and early shutdown, respectively.
        self.audio_queue = queue.Queue(maxsize=self.config.BUFFER_SIZE)
        self.playback_queue = None  # created on demand when play_audio=True
        self.generation_complete = threading.Event()
        self.stop_generation = threading.Event()

        # PyAudio handles (opened lazily by the playback worker).
        self.audio_interface = None
        self.audio_stream = None

    def _load_models(self, model_path: str):
        """Load BeigeTTS and NeuCodec models."""
        print("Loading BeigeTTS model...")
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
            device_map="auto",
            trust_remote_code=True,
        )
        self.model.eval()

        print("Loading tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        print("Loading NeuCodec...")
        self.neucodec = NeuCodec.from_pretrained("neuphonic/neucodec")
        self.neucodec.eval()
        if self.device.type == "cuda":
            self.neucodec = self.neucodec.to(self.device)

    def stream_synthesize(
        self,
        text: str,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        play_audio: bool = False,
    ) -> Generator[np.ndarray, None, None]:
        """Stream synthesized speech for the given text.

        Args:
            text: Input text to synthesize
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            play_audio: Whether to also play audio in real time

        Yields:
            Audio chunks as numpy arrays
        """
        # Explicit "is None" checks so a caller may pass 0.0 (greedy decoding)
        # without it being silently replaced by the default.
        temperature = self.config.DEFAULT_TEMPERATURE if temperature is None else temperature
        top_p = self.config.DEFAULT_TOP_P if top_p is None else top_p

        self.generation_complete.clear()
        self.stop_generation.clear()

        # Generate audio tokens on a background thread so this generator can
        # yield chunks as soon as they are decoded.
        generation_thread = threading.Thread(
            target=self._generation_worker,
            args=(text, temperature, top_p),
        )
        generation_thread.start()

        # Playback gets its own (unbounded) queue: every chunk popped from
        # audio_queue below is both yielded to the caller and forwarded to the
        # playback worker, so neither consumer starves the other.
        playback_thread = None
        if play_audio:
            self.playback_queue = queue.Queue()
            playback_thread = threading.Thread(target=self._playback_worker)
            playback_thread.start()

        try:
            while not self.generation_complete.is_set() or not self.audio_queue.empty():
                try:
                    audio_chunk = self.audio_queue.get(timeout=0.1)
                except queue.Empty:
                    continue
                if playback_thread is not None:
                    self.playback_queue.put(audio_chunk)
                yield audio_chunk
        finally:
            self.stop_generation.set()
            generation_thread.join()
            if playback_thread is not None:
                self.playback_queue.put(None)  # sentinel: no more audio
                playback_thread.join()

    def _generation_worker(self, text: str, temperature: float, top_p: float):
        """Worker thread for token generation."""
        try:
            # Chat-style prompt with the speech-start marker appended.
            prompt = f"<start_of_turn>user\n{text}<end_of_turn>\n<start_of_turn>model\n<start_of_speech>"

            inputs = self.tokenizer(prompt, return_tensors="pt")
            input_ids = inputs.input_ids.to(self.model.device)

            # Incremental decoding state: after the first forward pass only the
            # newest token is fed in, and the KV cache carries the context.
            past_key_values = None
            current_ids = input_ids
            audio_token_buffer = []
            total_generated = 0

            print("Starting streaming generation...")
            print("(BeigeTTS research mode - Khaki production offers superior latency)")

            with torch.no_grad():
                while total_generated < self.config.MAX_TOTAL_TOKENS and not self.stop_generation.is_set():
                    outputs = self.model(
                        input_ids=current_ids,
                        past_key_values=past_key_values,
                        use_cache=True,
                    )

                    logits = outputs.logits[:, -1, :]
                    past_key_values = outputs.past_key_values

                    if temperature > 0:
                        probs = torch.nn.functional.softmax(logits / temperature, dim=-1)

                        # Nucleus (top-p) filtering: keep the smallest set of tokens
                        # whose cumulative probability exceeds top_p, then renormalize.
                        if top_p < 1.0:
                            sorted_probs, sorted_indices = torch.sort(probs, descending=True)
                            cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

                            sorted_indices_to_remove = cumulative_probs > top_p
                            # Shift right so the first token crossing the threshold is kept.
                            sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
                            sorted_indices_to_remove[:, 0] = 0

                            indices_to_remove = sorted_indices_to_remove.scatter(
                                1, sorted_indices, sorted_indices_to_remove
                            )
                            probs[indices_to_remove] = 0
                            probs = probs / probs.sum(dim=-1, keepdim=True)

                        next_token = torch.multinomial(probs, num_samples=1)
                    else:
                        # temperature == 0 means greedy decoding.
                        next_token = torch.argmax(logits, dim=-1, keepdim=True)

                    token_id = next_token.item()

                    # Stop on the end-of-speech marker or the tokenizer EOS.
                    if token_id == self.config.AUDIO_END_TOKEN or token_id == self.tokenizer.eos_token_id:
                        break

                    # Map audio tokens back to NeuCodec codes and buffer them.
                    if self.config.AUDIO_TOKEN_MIN <= token_id < self.config.AUDIO_TOKEN_MAX:
                        audio_token_buffer.append(token_id - self.config.NEUCODEC_BASE_OFFSET)

                        # Decode and enqueue a chunk once enough codes have accumulated.
                        if len(audio_token_buffer) >= self.config.CHUNK_SIZE:
                            audio_chunk = self._decode_chunk(audio_token_buffer[:self.config.CHUNK_SIZE])
                            self._enqueue_chunk(audio_chunk)
                            audio_token_buffer = audio_token_buffer[self.config.CHUNK_SIZE:]
                            print(f"Streamed chunk {total_generated // self.config.CHUNK_SIZE}")

                    current_ids = next_token
                    total_generated += 1

            # Flush any remaining codes as a final (shorter) chunk.
            if audio_token_buffer:
                audio_chunk = self._decode_chunk(audio_token_buffer)
                self._enqueue_chunk(audio_chunk)

            print(f"Generation complete. Total tokens: {total_generated}")

        except Exception as e:
            print(f"Generation error: {e}")
        finally:
            self.generation_complete.set()

    def _enqueue_chunk(self, audio_chunk: np.ndarray):
        """Put a decoded chunk on the bounded audio queue without risking deadlock.

        audio_queue has maxsize=BUFFER_SIZE, so an unconditional put() could block
        forever if the consumer has already stopped; poll with a timeout and give
        up once stop_generation is set.
        """
        while not self.stop_generation.is_set():
            try:
                self.audio_queue.put(audio_chunk, timeout=0.1)
                return
            except queue.Full:
                continue

    def _decode_chunk(self, audio_tokens: list) -> np.ndarray:
        """Decode a chunk of NeuCodec codes into a waveform."""
        # Clamp codes into the valid codebook range before decoding.
        audio_array = np.array(audio_tokens, dtype=np.int32)
        audio_array = np.clip(audio_array, 0, self.config.NEUCODEC_VOCABULARY_SIZE - 1)

        # Arrange the codes as (batch=1, codebook=1, num_codes) for decode_code().
        fsq_codes = torch.tensor(audio_array, dtype=torch.long)
        fsq_codes = fsq_codes.unsqueeze(0).unsqueeze(1)

        if self.device.type == "cuda":
            fsq_codes = fsq_codes.to(self.device)

        with torch.no_grad():
            wav = self.neucodec.decode_code(fsq_codes).cpu()

        # Squeeze batch/channel dimensions down to a 1-D waveform.
        if wav.dim() == 3:
            wav = wav[0, 0]
        elif wav.dim() == 2:
            wav = wav[0]

        wav = wav.numpy()

        # Peak-normalize to avoid clipping during playback.
        if np.abs(wav).max() > 0:
            wav = wav / np.abs(wav).max() * 0.95

        return wav

    def _playback_worker(self):
        """Worker thread for real-time audio playback."""
        try:
            self.audio_interface = pyaudio.PyAudio()
            self.audio_stream = self.audio_interface.open(
                format=pyaudio.paFloat32,
                channels=1,
                rate=self.config.SAMPLE_RATE,
                output=True,
            )

            print("Starting audio playback...")

            # Drain chunks forwarded by stream_synthesize() until the None
            # sentinel arrives or shutdown is requested with nothing queued.
            while True:
                try:
                    audio_chunk = self.playback_queue.get(timeout=0.1)
                except queue.Empty:
                    if self.stop_generation.is_set():
                        break
                    continue
                if audio_chunk is None:
                    break
                self.audio_stream.write(audio_chunk.astype(np.float32).tobytes())

        except Exception as e:
            print(f"Playback error: {e}")
        finally:
            if self.audio_stream:
                self.audio_stream.stop_stream()
                self.audio_stream.close()
            if self.audio_interface:
                self.audio_interface.terminate()
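
# Programmatic usage sketch (illustrative; mirrors what main() does via the CLI):
#
#   tts = StreamingBeigeTTS()
#   chunks = list(tts.stream_synthesize("Hello from the research release."))
#   sf.write("out.wav", np.concatenate(chunks), tts.config.SAMPLE_RATE)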


class AdaptiveBeigeTTS(StreamingBeigeTTS):
    """Advanced streaming with adaptive chunk sizing.

    Note: This demonstrates research concepts. The production Khaki system
    includes predictive buffering, voice activity detection, and
    neural vocoder post-processing for superior quality.
    """

    def __init__(self, model_path: str = "BlandAI/BeigeTTS"):
        super().__init__(model_path)

        # Bounds and bookkeeping for adaptive chunk sizing.
        self.min_chunk_size = 25
        self.max_chunk_size = 100
        self.target_latency_ms = 200
        self.generation_speed_ema = 0.0  # reserved for speed-based adaptation
        self.ema_alpha = 0.1

    def adaptive_stream(
        self,
        text: str,
        quality_priority: float = 0.5,
    ) -> Generator[np.ndarray, None, None]:
        """Stream with chunk size and sampling adapted to a speed/quality trade-off.

        Args:
            text: Input text
            quality_priority: Balance between speed and quality (0-1)

        Yields:
            Adaptively sized audio chunks
        """
        # Higher quality_priority means slightly hotter sampling and larger chunks.
        temperature = 0.05 + (0.15 * quality_priority)
        top_p = 0.9 + (0.08 * quality_priority)

        chunk_size = int(self.min_chunk_size +
                         (self.max_chunk_size - self.min_chunk_size) * quality_priority)

        print(f"Adaptive streaming: chunk_size={chunk_size}, temp={temperature:.2f}, top_p={top_p:.2f}")
        print("(Khaki production includes neural enhancement for optimal quality)")

        # Temporarily override the configured chunk size, restoring it afterwards.
        original_chunk = self.config.CHUNK_SIZE
        self.config.CHUNK_SIZE = chunk_size

        try:
            yield from self.stream_synthesize(text, temperature, top_p)
        finally:
            self.config.CHUNK_SIZE = original_chunk
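
# Example of the adaptive mapping above: quality_priority=0.5 gives
# temperature = 0.05 + 0.15 * 0.5 = 0.125, top_p = 0.9 + 0.08 * 0.5 = 0.94 and
# chunk_size = int(25 + 75 * 0.5) = 62; quality_priority=1.0 gives 0.20 / 0.98 / 100.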


def main():
    parser = argparse.ArgumentParser(description="BeigeTTS Streaming (Research Release)")
    parser.add_argument("text", type=str, help="Text to synthesize")
    parser.add_argument("-o", "--output", type=str, help="Output WAV file (optional)")
    parser.add_argument("-m", "--model", type=str, default="BlandAI/BeigeTTS", help="Model path")
    parser.add_argument("--play", action="store_true", help="Play audio in real-time")
    parser.add_argument("--adaptive", action="store_true", help="Use adaptive streaming")
    parser.add_argument("--quality", type=float, default=0.5, help="Quality priority (0-1)")

    args = parser.parse_args()

    print("BeigeTTS Streaming - Research Release")
    print("Production Khaki TTS: <50ms latency, 57 languages, unlimited duration")
    print("-" * 60)

    if args.adaptive:
        tts = AdaptiveBeigeTTS(model_path=args.model)
        stream_gen = tts.adaptive_stream(args.text, quality_priority=args.quality)
    else:
        tts = StreamingBeigeTTS(model_path=args.model)
        stream_gen = tts.stream_synthesize(args.text, play_audio=args.play)

    audio_chunks = []
    print("Streaming audio generation...")

    for i, chunk in enumerate(stream_gen):
        audio_chunks.append(chunk)
        print(f"  Received chunk {i + 1} ({len(chunk) / tts.config.SAMPLE_RATE:.2f}s)")

    if args.output and audio_chunks:
        full_audio = np.concatenate(audio_chunks)
        sf.write(args.output, full_audio, tts.config.SAMPLE_RATE)
        duration = len(full_audio) / tts.config.SAMPLE_RATE
        print(f"\nSaved {duration:.1f}s of audio to {args.output}")

    print("\nStreaming complete!")
    print("For commercial use and advanced features, contact partnerships@bland.ai")


if __name__ == "__main__":
    main()