#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
# This code was originally written by Marmik Pandya (marmikcfc - github.com/marmikcfc)
import numpy as np
from typing import AsyncGenerator, List, Optional, Union
from loguru import logger
from pydantic import BaseModel
from pipecat.frames.frames import (
ErrorFrame,
Frame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.services.tts_service import TTSService
from pipecat.transcriptions.language import Language
# load Kokoro from kokoro-onnx
try:
from kokoro_onnx import Kokoro
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error(
"In order to use Kokoro, you need to `pip install kokoro-onnx`. Also, download the model files from the Kokoro repository."
)
raise Exception(f"Missing module: {e}")
def language_to_kokoro_language(language: Language) -> Optional[str]:
"""Convert pipecat Language to Kokoro language code."""
BASE_LANGUAGES = {
Language.EN: "en-us",
Language.FR: "fr-fr",
Language.IT: "it",
Language.JA: "ja",
        Language.CMN: "cmn",
# Add more language mappings as supported by Kokoro
}
result = BASE_LANGUAGES.get(language)
    # If the exact language is not in the map, fall back to the base code of a
    # regional variant (e.g. "en-GB" -> "en"). Currently only English variants
    # have a Kokoro fallback ("en-us"); other unmapped languages return None.
    if not result:
        lang_str = str(language.value)
        base_code = lang_str.split("-")[0].lower()
        result = "en-us" if base_code == "en" else None
return result
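# Illustrative behavior of the mapping above (assuming pipecat's Language enum
# defines regional variants such as Language.EN_GB; examples only, not exhaustive):
#   language_to_kokoro_language(Language.EN)    -> "en-us"
#   language_to_kokoro_language(Language.EN_GB) -> "en-us"  (falls back to base "en")
#   language_to_kokoro_language(Language.DE)    -> None     (not mapped above)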
class KokoroTTSService(TTSService):
"""Text-to-Speech service using Kokoro for on-device TTS.
This service uses Kokoro to generate speech without requiring external API connections.
"""
class InputParams(BaseModel):
"""Configuration parameters for Kokoro TTS service."""
language: Optional[Language] = Language.EN
speed: Optional[float] = 1.0
def __init__(
self,
*,
model_path: str,
voices_path: str,
voice_id: str = "af_sarah",
sample_rate: Optional[int] = None,
params: InputParams = InputParams(),
**kwargs,
):
"""Initialize Kokoro TTS service.
Args:
model_path: Path to the Kokoro ONNX model file
voices_path: Path to the Kokoro voices file
voice_id: ID of the voice to use
sample_rate: Output audio sample rate
params: Additional configuration parameters
"""
super().__init__(sample_rate=sample_rate, **kwargs)
logger.info(f"Initializing Kokoro TTS service with model_path: {model_path} and voices_path: {voices_path}")
self._kokoro = Kokoro(model_path, voices_path)
logger.info(f"Kokoro initialized")
self._settings = {
"language": self.language_to_service_language(params.language)
if params.language
else "en-us",
"speed": params.speed,
}
        self.set_voice(voice_id)  # TTSService.set_voice() stores the voice as self._voice_id
logger.info("Kokoro TTS service initialized")
def can_generate_metrics(self) -> bool:
return True
def language_to_service_language(self, language: Language) -> Optional[str]:
"""Convert pipecat language to Kokoro language code."""
return language_to_kokoro_language(language)
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
"""Generate speech from text using Kokoro in a streaming fashion.
Args:
text: The text to convert to speech
Yields:
Frames containing audio data and status information.
"""
logger.debug(f"Generating TTS: [{text}]")
try:
await self.start_ttfb_metrics()
yield TTSStartedFrame()
# Use Kokoro's streaming mode. The create_stream method is assumed to return
# an async generator that yields (samples, sample_rate) tuples, where samples is a numpy array.
logger.info(f"Creating stream")
stream = self._kokoro.create_stream(
text,
voice=self._voice_id,
speed=self._settings["speed"],
lang=self._settings["language"],
)
await self.start_tts_usage_metrics(text)
started = False
async for samples, sample_rate in stream:
if not started:
started = True
logger.info(f"Started streaming")
logger.info(f"Started streaming")
logger.info(f"Started streaming")
logger.info(f"Started streaming")
logger.info(f"Started streaming")
logger.info(f"Started streaming")
logger.info(f"Started streaming")
logger.info(f"Started streaming")
logger.info(f"Started streaming")
logger.info(f"Started streaming")
logger.info(f"Started streaming")
logger.info(f"Started streaming")
logger.info(f"Started streaming")
logger.info(f"Started streaming")
logger.info(f"Started streaming")
logger.info(f"Started streaming")
logger.info(f"Started streaming")
logger.info(f"Started streaming")
# Convert the float32 samples (assumed in the range [-1, 1]) to int16 PCM format
samples_int16 = (samples * 32767).astype(np.int16)
yield TTSAudioRawFrame(
audio=samples_int16.tobytes(),
sample_rate=sample_rate,
num_channels=1,
)
yield TTSStoppedFrame()
except Exception as e:
logger.error(f"{self} exception: {e}")
yield ErrorFrame(f"Error generating audio: {str(e)}")
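# Example usage (a minimal sketch; the model and voices file names below are
# illustrative kokoro-onnx downloads, and the surrounding pipeline wiring is
# assumed rather than defined in this module):
#
#   tts = KokoroTTSService(
#       model_path="kokoro-v0_19.onnx",
#       voices_path="voices.bin",
#       voice_id="af_sarah",
#       params=KokoroTTSService.InputParams(language=Language.EN, speed=1.0),
#   )
#   pipeline = Pipeline([transport.input(), ..., llm, tts, transport.output()])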