File size: 5,444 Bytes
a316f58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import os
import time
import logging
import requests
import numpy as np
import soundfile as sf
from typing import Optional, Tuple, Generator

# Configure logging.
# NOTE: basicConfig is called at import time, so importing this module
# requests INFO-level configuration of the root logger.
logging.basicConfig(level=logging.INFO)
# Module-scoped logger used by every function in this file.
logger = logging.getLogger(__name__)

# Constants
# Sample rate in Hz; not referenced by the code visible in this part of the file.
DEFAULT_SAMPLE_RATE = 44100
# Base URL of the Dia Space server; used by _get_client (connectivity probe
# against "/docs") and _call_dia_api (POST to "/v1/audio/speech").
DEFAULT_API_URL = "https://droolingpanda-dia-tts-server.hf.space"
# Value sent as the "model" field of the speech request payload.
DEFAULT_MODEL = "dia-1.6b"

# Global client instance (lazy loaded): stays None until _get_client()
# creates the shared requests.Session.
_client = None


def _get_client():
    """Return the shared HTTP session for the Dia Space API, creating it on first use.

    On the first call a ``requests.Session`` is created and a GET request is
    sent to ``{DEFAULT_API_URL}/docs`` as a connectivity probe. The session is
    stored in the module-level ``_client`` only after the probe has returned,
    so a probe that raises does not leave an unverified session cached for
    later calls; the next call simply retries. A non-200 probe response is
    logged as a warning and the session is still cached.

    Returns:
        requests.Session: The cached session shared by all callers.

    Raises:
        ImportError: If ``requests`` cannot be imported.
        Exception: Any other error raised while creating the session or
            probing the API is logged and re-raised unchanged.
    """
    global _client
    if _client is not None:
        return _client

    # requests applies no timeout unless one is passed, so an unreachable
    # host would otherwise block the caller indefinitely.
    probe_timeout_seconds = 30

    logger.info("Loading Dia Space client...")
    session = None
    try:
        # Local import mirrors the module-level one; a failure here is
        # reported by the ImportError handler below.
        import requests

        # Initialize the client (just a session for now)
        logger.info("Initializing Dia Space client")
        session = requests.Session()

        # Test connection to the API
        response = session.get(f"{DEFAULT_API_URL}/docs", timeout=probe_timeout_seconds)
        if response.status_code == 200:
            logger.info("Dia Space client loaded successfully")
            logger.info(f"Client type: {type(session).__name__}")
        else:
            logger.warning(f"Dia Space API returned status code {response.status_code}")
    except ImportError as import_err:
        logger.error(f"Import error loading Dia Space client: {import_err}")
        logger.error("This may indicate missing dependencies")
        raise
    except Exception as e:
        logger.error(f"Error loading Dia Space client: {e}", exc_info=True)
        logger.error(f"Error type: {type(e).__name__}")
        # Release the connection pool of the session that will not be kept.
        if session is not None:
            session.close()
        raise

    # Publish the session only once the probe has completed.
    _client = session
    return _client


def generate_speech(
    text: str,
    language: str = "zh",
    voice: str = "S1",
    response_format: str = "wav",
    speed: float = 1.0,
) -> str:
    """Legacy entry point for Dia Space text-to-speech.

    Kept for backward compatibility; it delegates to ``DiaSpaceTTSEngine``
    from ``utils.tts_engines``. If creating the engine or synthesizing
    raises, the error is logged and ``DummyTTSEngine`` from
    ``utils.tts_base`` is used instead.

    Args:
        text (str): Input text to synthesize
        language (str): Language code, passed to the ``DiaSpaceTTSEngine`` constructor
        voice (str): Voice mode to use ('S1', 'S2', 'dialogue', or filename for clone)
        response_format (str): Audio format ('wav', 'mp3', 'opus')
        speed (float): Speech speed multiplier

    Returns:
        str: Path to the generated audio file, as returned by the engine
        that handled the request
    """
    logger.info(f"Legacy Dia Space generate_speech called with text length: {len(text)}")

    # Imported outside the try block, exactly as before: a failure to import
    # the engine module propagates instead of triggering the dummy fallback.
    from utils.tts_engines import DiaSpaceTTSEngine

    try:
        engine = DiaSpaceTTSEngine(language)
        audio_path = engine.generate_speech(text, voice, speed, response_format)
    except Exception as exc:
        logger.error(f"Error in legacy Dia Space generate_speech: {str(exc)}", exc_info=True)
        # Fall back to dummy TTS
        from utils.tts_base import DummyTTSEngine
        return DummyTTSEngine().generate_speech(text)
    return audio_path


def _create_output_dir() -> str:
    """Ensure the audio output directory exists and return its path.

    The directory is the relative path ``temp/outputs``; missing parent
    directories are created and an already existing directory is accepted.

    Returns:
        str: Path to the output directory
    """
    audio_dir = "temp/outputs"
    # exist_ok makes repeated calls harmless
    os.makedirs(name=audio_dir, exist_ok=True)
    return audio_dir


def _generate_output_path(prefix: str = "output", extension: str = "wav") -> str:
    """Generate a unique output path for audio files

    The file name has the form ``{prefix}_{timestamp}_{suffix}.{extension}``,
    where ``timestamp`` is the current Unix time in whole seconds and
    ``suffix`` is eight random hexadecimal characters. The timestamp alone
    has one-second resolution, so without the suffix two calls made in the
    same second with the same prefix would return the same path and the
    later file would overwrite the earlier one.

    Args:
        prefix (str): Prefix for the output filename
        extension (str): File extension for the output file (without a leading dot)

    Returns:
        str: Path to the output file inside the directory returned by
        ``_create_output_dir``; the file itself is not created here
    """
    # Function-scope import: uuid is not among this file's top-level imports.
    import uuid

    output_dir = _create_output_dir()
    timestamp = int(time.time())
    unique_suffix = uuid.uuid4().hex[:8]
    return f"{output_dir}/{prefix}_{timestamp}_{unique_suffix}.{extension}"


def _call_dia_api(text: str, voice: str = "S1", response_format: str = "wav", speed: float = 1.0, timeout: float = 300.0) -> bytes:
    """Call the Dia Space API to generate speech

    Sends a JSON POST request to ``{DEFAULT_API_URL}/v1/audio/speech`` using
    the shared session from ``_get_client`` and returns the raw response body.

    Args:
        text (str): Input text to synthesize
        voice (str): Voice mode to use ('S1', 'S2', 'dialogue', or filename for clone)
        response_format (str): Audio format ('wav', 'mp3', 'opus')
        speed (float): Speech speed multiplier
        timeout (float): Seconds to wait for the server before giving up.
            The default is a deliberately generous guess; tune it to the
            synthesis times actually observed.

    Returns:
        bytes: Audio data

    Raises:
        RuntimeError: If the API answers with a status code other than 200.
        Exception: Any error raised while sending the request (connection
            failure, timeout, ...) is logged and re-raised unchanged.
    """
    client = _get_client()

    # Prepare the request payload
    payload = {
        "model": DEFAULT_MODEL,
        "input": text,
        "voice": voice,
        "response_format": response_format,
        "speed": speed
    }

    # Make the API request
    logger.info(f"Calling Dia Space API with voice: {voice}, format: {response_format}, speed: {speed}")
    try:
        # Without an explicit timeout requests waits forever on a stalled server.
        response = client.post(
            f"{DEFAULT_API_URL}/v1/audio/speech",
            json=payload,
            headers={"Content-Type": "application/json"},
            timeout=timeout
        )
    except Exception as e:
        logger.error(f"Error calling Dia Space API: {str(e)}", exc_info=True)
        raise

    # The status check lives outside the try block so an HTTP error is
    # logged once here instead of being caught and logged a second time.
    if response.status_code != 200:
        logger.error(f"Dia Space API returned error: {response.status_code}")
        logger.error(f"Response: {response.text}")
        raise RuntimeError(f"Dia Space API error: {response.status_code}")

    logger.info("Dia Space API call successful")
    return response.content