# Michael Hu
# fix circular dependency
# aaa0814
# raw
# history blame
# 4.64 kB
import logging
from typing import Optional, Generator, Tuple, List, Dict, Any
import numpy as np
# Import the base class and dummy implementation
from utils.tts_base import TTSBase
from utils.tts_dummy import DummyTTS
# Import the specific TTS implementations
from utils.tts_kokoro import KokoroTTS, KOKORO_AVAILABLE
from utils.tts_dia import DiaTTS, DIA_AVAILABLE
from utils.tts_cosyvoice2 import CosyVoice2TTS, COSYVOICE2_AVAILABLE
# Module-level logger named after this module
logger = logging.getLogger(__name__)
def get_available_engines() -> List[str]:
    """List the names of the TTS engines that can currently be used.

    Each optional engine is reported only when its module-level
    ``*_AVAILABLE`` flag is truthy. The ``'dummy'`` engine is appended
    unconditionally, so the result is never empty.

    Returns:
        List[str]: Engine names in the order 'kokoro', 'dia',
            'cosyvoice2' (each only if available), followed by 'dummy'.
    """
    # Pair every optional engine name with the flag exported by its module.
    optional_engines = (
        ('kokoro', KOKORO_AVAILABLE),
        ('dia', DIA_AVAILABLE),
        ('cosyvoice2', COSYVOICE2_AVAILABLE),
    )
    names = [name for name, usable in optional_engines if usable]
    # The dummy engine has no availability flag; it is always offered.
    names.append('dummy')
    return names
def get_tts_engine(engine_type: Optional[str] = None, lang_code: str = 'z') -> TTSBase:
    """Create a TTS engine instance.

    When ``engine_type`` names an engine whose availability flag is set
    (or is ``'dummy'``), that engine is constructed. Otherwise a warning
    is logged and the first available engine in the priority order
    'cosyvoice2', 'kokoro', 'dia', 'dummy' is constructed instead.

    Args:
        engine_type (str, optional): Engine to create ('kokoro', 'dia',
            'cosyvoice2', 'dummy'). If None, the best available engine
            is chosen by priority.
        lang_code (str): Language code passed to the engine constructor.

    Returns:
        TTSBase: An instance of a TTS engine.
    """
    available_engines = get_available_engines()
    logger.info(f"Available TTS engines: {available_engines}")

    # One row per engine: (name, usable-now flag, class, label for the log line).
    registry = (
        ('kokoro', KOKORO_AVAILABLE, KokoroTTS, 'Kokoro'),
        ('dia', DIA_AVAILABLE, DiaTTS, 'Dia'),
        ('cosyvoice2', COSYVOICE2_AVAILABLE, CosyVoice2TTS, 'CosyVoice2'),
        ('dummy', True, DummyTTS, 'Dummy'),
    )

    # Explicit request: honour it only if that engine is usable right now.
    if engine_type is not None:
        for name, usable, factory, label in registry:
            if engine_type == name and usable:
                logger.info(f"Creating {label} TTS engine")
                return factory(lang_code)
        # Unknown name or unavailable engine: fall through to priority selection.
        logger.warning(f"Requested engine '{engine_type}' is not available")

    # No (usable) explicit request: pick the first available engine by priority.
    factories = {name: factory for name, _, factory, _ in registry}
    for candidate in ('cosyvoice2', 'kokoro', 'dia', 'dummy'):
        if candidate not in available_engines:
            continue
        logger.info(f"Using best available engine: {candidate}")
        return factories[candidate](lang_code)

    # Defensive fallback in case the availability list contained no known engine.
    logger.warning("No TTS engines available, falling back to dummy engine")
    return DummyTTS(lang_code)
def generate_speech(text: str, engine_type: Optional[str] = None, lang_code: str = 'z',
                    voice: str = 'default', speed: float = 1.0) -> Optional[str]:
    """Synthesize speech with the requested or best available TTS engine.

    The engine is obtained from ``get_tts_engine`` and the synthesis call
    is delegated to its ``generate_speech`` method.

    Args:
        text (str): Input text to synthesize.
        engine_type (str, optional): Type of engine to use; None selects
            the best available engine.
        lang_code (str): Language code used when creating the engine.
        voice (str): Voice ID forwarded to the engine.
        speed (float): Speech speed multiplier forwarded to the engine.

    Returns:
        Optional[str]: Path to the generated audio file, or None if
            generation fails.
    """
    return get_tts_engine(engine_type, lang_code).generate_speech(text, voice, speed)
def generate_speech_stream(text: str, engine_type: Optional[str] = None, lang_code: str = 'z',
                           voice: str = 'default', speed: float = 1.0) -> Generator[Tuple[int, np.ndarray], None, None]:
    """Stream synthesized speech from the requested or best available TTS engine.

    The engine is obtained from ``get_tts_engine`` when the generator is
    first advanced, and iteration is delegated to the engine's
    ``generate_speech_stream`` method.

    Args:
        text (str): Input text to synthesize.
        engine_type (str, optional): Type of engine to use; None selects
            the best available engine.
        lang_code (str): Language code used when creating the engine.
        voice (str): Voice ID forwarded to the engine.
        speed (float): Speech speed multiplier forwarded to the engine.

    Yields:
        tuple: (sample_rate, audio_data) pairs for each segment.
    """
    # `yield from` keeps full generator delegation to the engine's stream.
    yield from get_tts_engine(engine_type, lang_code).generate_speech_stream(text, voice, speed)