teachingAssistant / utils /tts_dia_space.py
Michael Hu
handle dia model not available
a316f58
import os
import time
import logging
import requests
import numpy as np
import soundfile as sf
from typing import Optional, Tuple, Generator
# Logging: basicConfig attaches an INFO-level handler to the root logger
# (only if the root logger has no handlers yet); this module logs via `logger`.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Defaults for talking to the hosted Dia TTS server.
DEFAULT_SAMPLE_RATE: int = 44100
DEFAULT_API_URL: str = "https://droolingpanda-dia-tts-server.hf.space"
DEFAULT_MODEL: str = "dia-1.6b"

# Shared HTTP client; stays None until _get_client() creates it on first use.
_client = None
def _get_client(timeout: float = 30.0):
    """Return the shared ``requests.Session`` used to talk to the Dia Space API.

    The session is created lazily on the first call and cached in the
    module-level ``_client``. When it is created, ``GET {DEFAULT_API_URL}/docs``
    is issued as a connectivity check. A non-200 status is only logged as a
    warning; the session is still cached and returned.

    Args:
        timeout (float): Seconds to wait for the connectivity check before
            giving up. Without a timeout ``requests`` would wait indefinitely.

    Returns:
        requests.Session: The cached session.

    Raises:
        Exception: Any error raised while creating the session or running the
            connectivity check (for example ``requests.ConnectionError`` or
            ``requests.Timeout``) is logged and re-raised. In that case nothing
            is cached, so the next call retries from scratch.
    """
    global _client
    if _client is not None:
        return _client

    logger.info("Loading Dia Space client...")
    session = None
    try:
        logger.info("Initializing Dia Space client")
        session = requests.Session()
        # Connectivity check against the server's docs page.
        response = session.get(f"{DEFAULT_API_URL}/docs", timeout=timeout)
    except Exception as e:
        logger.error("Error loading Dia Space client: %s", e, exc_info=True)
        logger.error("Error type: %s", type(e).__name__)
        # Don't leak the half-initialised session, and don't cache it: the
        # next call should retry the connectivity check.
        if session is not None:
            session.close()
        raise

    if response.status_code == 200:
        logger.info("Dia Space client loaded successfully")
        logger.info("Client type: %s", type(session).__name__)
    else:
        logger.warning("Dia Space API returned status code %s", response.status_code)

    # Only publish the session once the connectivity check has completed.
    _client = session
    return _client
def generate_speech(text: str, language: str = "zh", voice: str = "S1",
                    response_format: str = "wav", speed: float = 1.0) -> str:
    """Public interface for TTS generation using the Dia Space API.

    This is a legacy function maintained for backward compatibility.
    New code should use the factory pattern implementation directly.

    Any failure, including the Dia engine module itself being unavailable,
    makes this function fall back to ``DummyTTSEngine``.

    Args:
        text (str): Input text to synthesize.
        language (str): Language code (not used by Dia Space; passed to the
            engine constructor for API compatibility).
        voice (str): Voice mode to use ('S1', 'S2', 'dialogue', or filename
            for clone).
        response_format (str): Audio format ('wav', 'mp3', 'opus').
        speed (float): Speech speed multiplier.

    Returns:
        str: Path to the generated audio file (from the Dia engine, or from
        the dummy engine on fallback).

    Raises:
        Exception: Only if the dummy fallback itself fails.
    """
    logger.info("Legacy Dia Space generate_speech called with text length: %d", len(text))
    try:
        # Imported inside the try so that an unavailable Dia engine
        # (ImportError from the engine module or its dependencies) triggers
        # the dummy fallback below instead of escaping to the caller.
        from utils.tts_engines import DiaSpaceTTSEngine

        dia_space_engine = DiaSpaceTTSEngine(language)
        # NOTE(review): positional order (text, voice, speed, response_format)
        # assumes DiaSpaceTTSEngine.generate_speech's signature — confirm.
        return dia_space_engine.generate_speech(text, voice, speed, response_format)
    except Exception as e:
        logger.error("Error in legacy Dia Space generate_speech: %s", e, exc_info=True)
        # Fall back to dummy TTS
        from utils.tts_base import DummyTTSEngine

        dummy_engine = DummyTTSEngine()
        return dummy_engine.generate_speech(text)
def _create_output_dir() -> str:
"""Create output directory for audio files
Returns:
str: Path to the output directory
"""
output_dir = "temp/outputs"
os.makedirs(output_dir, exist_ok=True)
return output_dir
def _generate_output_path(prefix: str = "output", extension: str = "wav") -> str:
"""Generate a unique output path for audio files
Args:
prefix (str): Prefix for the output filename
extension (str): File extension for the output file
Returns:
str: Path to the output file
"""
output_dir = _create_output_dir()
timestamp = int(time.time())
return f"{output_dir}/{prefix}_{timestamp}.{extension}"
def _call_dia_api(text: str, voice: str = "S1", response_format: str = "wav",
                  speed: float = 1.0, timeout: float = 300.0) -> bytes:
    """Call the Dia Space API to generate speech.

    Sends ``POST {DEFAULT_API_URL}/v1/audio/speech`` with a JSON payload
    (``model``, ``input``, ``voice``, ``response_format``, ``speed``) through
    the shared session from ``_get_client()``.

    Args:
        text (str): Input text to synthesize.
        voice (str): Voice mode to use ('S1', 'S2', 'dialogue', or filename
            for clone).
        response_format (str): Audio format ('wav', 'mp3', 'opus').
        speed (float): Speech speed multiplier.
        timeout (float): Seconds to wait for the server before giving up.
            Generous by default because synthesis can be slow; without a
            timeout ``requests`` would wait indefinitely.

    Returns:
        bytes: Raw audio data from the response body.

    Raises:
        RuntimeError: If the API responds with a non-200 status.
        Exception: Any error from ``_get_client()`` or from the HTTP request
            itself (e.g. ``requests.ConnectionError``, ``requests.Timeout``)
            propagates; request errors are logged first.
    """
    client = _get_client()

    payload = {
        "model": DEFAULT_MODEL,
        "input": text,
        "voice": voice,
        "response_format": response_format,
        "speed": speed,
    }

    logger.info(
        "Calling Dia Space API with voice: %s, format: %s, speed: %s",
        voice, response_format, speed,
    )
    try:
        # json= serialises the payload and sets Content-Type: application/json.
        response = client.post(
            f"{DEFAULT_API_URL}/v1/audio/speech",
            json=payload,
            timeout=timeout,
        )
    except Exception as e:
        logger.error("Error calling Dia Space API: %s", e, exc_info=True)
        raise

    # Status handling lives outside the try so HTTP errors are logged once.
    if response.status_code != 200:
        logger.error("Dia Space API returned error: %s", response.status_code)
        logger.error("Response: %s", response.text)
        raise RuntimeError(f"Dia Space API error: {response.status_code}")

    logger.info("Dia Space API call successful")
    return response.content