File size: 5,296 Bytes
030c851
 
 
 
 
 
 
 
 
 
 
 
cb90410
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
030c851
 
 
 
 
 
 
 
cb90410
030c851
 
cb90410
 
 
 
 
 
030c851
 
 
91223c9
 
 
 
 
 
 
 
 
 
 
 
030c851
91223c9
 
 
 
2d176f4
 
 
 
 
91223c9
 
 
 
 
 
 
 
030c851
 
91223c9
 
030c851
 
 
c549dab
030c851
 
 
3ed3b5a
 
 
030c851
 
 
 
 
 
 
3ed3b5a
4a9bb1a
cb90410
 
 
 
 
 
4a9bb1a
cb90410
4a9bb1a
cb90410
 
 
3ed3b5a
 
 
cb90410
 
 
 
 
 
 
 
3ed3b5a
 
4a9bb1a
3ed3b5a
cb90410
3ed3b5a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import os
import time
import logging
import numpy as np
import soundfile as sf
from pathlib import Path
from typing import Optional

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Flag to track Dia availability. It stays False unless both torch and
# dia.model import cleanly below.
DIA_AVAILABLE = False

# Probe the optional dependencies in two separate stages so that each failure
# is reported against the package that actually caused it.
try:
    import torch
except ImportError:
    logger.warning("Torch not available, Dia TTS engine cannot be used")
else:
    # Importing Dia transitively imports dac, which is the usual missing piece.
    try:
        from dia.model import Dia
    except ImportError as e:
        # Catch ImportError (not only its ModuleNotFoundError subclass): a
        # broken Dia install that raises e.g. "cannot import name ..." must
        # not be misreported as a missing torch.
        if isinstance(e, ModuleNotFoundError) and "dac" in str(e):
            logger.warning("Dia TTS engine is not available due to missing 'dac' module")
        else:
            logger.warning(f"Dia TTS engine is not available: {str(e)}")
    else:
        DIA_AVAILABLE = True
        logger.info("Dia TTS engine is available")

# Constants
# NOTE(review): not referenced in the visible part of this module; presumably
# the sample rate (Hz) of the generated audio -- confirm against callers.
DEFAULT_SAMPLE_RATE = 44100
# Model identifier handed to Dia.from_pretrained() by _get_model().
DEFAULT_MODEL_NAME = "nari-labs/Dia-1.6B"

# Global model instance (lazy loaded): stays None until _get_model() loads
# the model, after which the same instance is returned on every call.
_model = None


def _get_model():
    """Return the shared Dia model, loading it on first use.

    The model is created with ``Dia.from_pretrained(DEFAULT_MODEL_NAME,
    compute_dtype="float16")`` and cached in the module-level ``_model`` so
    later calls return the same instance without reloading.

    Returns:
        The cached Dia model instance.

    Raises:
        ImportError: If ``DIA_AVAILABLE`` is false, or if loading the model
            itself raises an ImportError.
        FileNotFoundError: If loading the model raises it.
        Exception: Any other error raised while loading is logged and
            re-raised unchanged.
    """
    global _model

    # Check if Dia is available before attempting to load
    if not DIA_AVAILABLE:
        logger.warning("Dia is not available, cannot load model")
        raise ImportError("Dia module is not available")

    if _model is not None:
        return _model

    logger.info("Loading Dia model...")
    try:
        # Record the runtime environment first; it is the most useful context
        # when the load below fails.
        logger.info(f"PyTorch version: {torch.__version__}")
        cuda_available = torch.cuda.is_available()
        logger.info(f"CUDA available: {cuda_available}")
        if cuda_available:
            logger.info(f"CUDA version: {torch.version.cuda}")
            logger.info(f"GPU device: {torch.cuda.get_device_name(0)}")

        logger.info(f"Attempting to load model from: {DEFAULT_MODEL_NAME}")
        logger.info("Initializing Dia model...")
        # Load into a local so the global cache is only published once the
        # load has fully succeeded.
        loaded = Dia.from_pretrained(DEFAULT_MODEL_NAME, compute_dtype="float16")
    except ImportError as import_err:
        logger.error(f"Import error loading Dia model: {import_err}")
        logger.error("This may indicate missing dependencies")
        raise
    except FileNotFoundError as file_err:
        logger.error(f"File not found error loading Dia model: {file_err}")
        logger.error("Model path may be incorrect or inaccessible")
        raise
    except Exception as e:
        logger.error(f"Error loading Dia model: {e}", exc_info=True)
        logger.error(f"Error type: {type(e).__name__}")
        logger.error("This may indicate incompatible versions or missing CUDA support")
        raise

    _model = loaded
    logger.info("Dia model loaded successfully")
    logger.info(f"Model type: {type(loaded).__name__}")

    # Device reporting is purely diagnostic and must never turn a successful
    # load into a failure. PyTorch modules expose parameters(); Dia might not,
    # and a model without parameters yields an empty iterator.
    first_param = None
    if hasattr(loaded, 'parameters'):
        try:
            first_param = next(iter(loaded.parameters()), None)
        except Exception as diag_err:
            logger.debug(f"Could not inspect model parameters: {diag_err}")
    if first_param is not None:
        logger.info(f"Model device: {first_param.device}")
    else:
        logger.info("Model device: Device information not available for Dia model")

    return _model


def _fallback_to_dummy(text: str, language: str) -> str:
    """Synthesize *text* with DummyTTSEngine and return that engine's result."""
    # Imported at call time, as the original code did; presumably for the same
    # circular-import reason as DiaTTSEngine -- confirm.
    from utils.tts_base import DummyTTSEngine

    return DummyTTSEngine(language).generate_speech(text)


def generate_speech(text: str, language: str = "zh") -> str:
    """Public interface for TTS generation using Dia model

    This is a legacy function maintained for backward compatibility.
    New code should use the factory pattern implementation directly.

    The call is delegated to ``DiaTTSEngine``. If Dia is unavailable
    (``DIA_AVAILABLE`` is false) or the Dia engine raises any exception, the
    error is logged and ``DummyTTSEngine`` is used instead, so errors from the
    Dia path are never propagated to the caller.

    Args:
        text (str): Input text to synthesize
        language (str): Language code; passed to the engine constructor
            (documented as not used by the Dia model, kept for API
            compatibility)

    Returns:
        str: Path to the generated audio file, as returned by the engine's
        ``generate_speech``
    """
    logger.info(f"Legacy Dia generate_speech called with text length: {len(text)}")

    # Check if Dia is available
    if not DIA_AVAILABLE:
        logger.warning("Dia is not available, falling back to dummy TTS engine")
        return _fallback_to_dummy(text, language)

    # Use the new implementation via factory pattern
    try:
        # Import here to avoid circular imports
        from utils.tts_engines import DiaTTSEngine

        return DiaTTSEngine(language).generate_speech(text)
    except ModuleNotFoundError as e:
        logger.error(f"Module not found error in Dia generate_speech: {str(e)}")
        if "dac" in str(e):
            logger.warning("Dia TTS engine failed due to missing 'dac' module, falling back to dummy TTS")
    except Exception as e:
        logger.error(f"Error in legacy Dia generate_speech: {str(e)}", exc_info=True)

    # Both failure branches end up here: fall back to dummy TTS.
    return _fallback_to_dummy(text, language)