import os
import time
import logging
import numpy as np
import soundfile as sf
from pathlib import Path
from typing import Optional

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Flag to track Dia availability; becomes True only when both torch and
# dia.model import successfully.
DIA_AVAILABLE = False

# Try to import required dependencies. torch and Dia are probed separately so
# that an ImportError raised while importing Dia is not misreported as torch
# being unavailable.
try:
    import torch
except ImportError:
    logger.warning("Torch not available, Dia TTS engine cannot be used")
    DIA_AVAILABLE = False
else:
    # Importing Dia also imports its dependencies (e.g. 'dac').
    try:
        from dia.model import Dia
    except ModuleNotFoundError as e:
        if "dac" in str(e):
            logger.warning("Dia TTS engine is not available due to missing 'dac' module")
        else:
            logger.warning(f"Dia TTS engine is not available: {str(e)}")
        DIA_AVAILABLE = False
    except ImportError as e:
        # Non-"module not found" import failures (e.g. "cannot import name ...")
        # come from the Dia install, not from torch.
        logger.warning(f"Dia TTS engine is not available: {str(e)}")
        DIA_AVAILABLE = False
    else:
        DIA_AVAILABLE = True
        logger.info("Dia TTS engine is available")

# Constants
DEFAULT_SAMPLE_RATE = 44100
DEFAULT_MODEL_NAME = "nari-labs/Dia-1.6B"

# Global model instance (lazy loaded by _get_model)
_model = None


def _log_torch_environment():
    """Log the PyTorch version and, when present, CUDA/GPU details."""
    logger.info(f"PyTorch version: {torch.__version__}")
    logger.info(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        logger.info(f"CUDA version: {torch.version.cuda}")
        logger.info(f"GPU device: {torch.cuda.get_device_name(0)}")


def _log_model_device(model):
    """Best-effort log of the device holding the model's first parameter.

    Dia may not expose a ``parameters()`` method (unlike a torch module).
    A model whose ``parameters()`` yields nothing is also handled. In both
    cases a placeholder message is logged instead of raising.
    """
    params = getattr(model, "parameters", None)
    first_param = next(iter(params()), None) if callable(params) else None
    if first_param is not None:
        logger.info(f"Model device: {first_param.device}")
    else:
        logger.info("Model device: Device information not available for Dia model")


def _get_model():
    """Return the module-wide Dia model, loading it on first use.

    Loads ``DEFAULT_MODEL_NAME`` via ``Dia.from_pretrained`` with
    ``compute_dtype="float16"``. The result is cached in the module-level
    ``_model``, and later calls return the cached instance. The cache is only
    populated after loading succeeds, so a failed load is retried next call.

    Returns:
        The loaded Dia model instance.

    Raises:
        ImportError: If Dia could not be imported when this module loaded, or
            if loading raises ImportError.
        FileNotFoundError: If loading raises it (re-raised after logging).
        Exception: Any other loading error is logged and re-raised.
    """
    global _model

    # Check if Dia is available before attempting to load
    if not DIA_AVAILABLE:
        logger.warning("Dia is not available, cannot load model")
        raise ImportError("Dia module is not available")

    if _model is not None:
        return _model

    logger.info("Loading Dia model...")
    try:
        _log_torch_environment()
        logger.info(f"Attempting to load model from: {DEFAULT_MODEL_NAME}")
        logger.info("Initializing Dia model...")
        model = Dia.from_pretrained(DEFAULT_MODEL_NAME, compute_dtype="float16")
        logger.info("Dia model loaded successfully")
        logger.info(f"Model type: {type(model).__name__}")
        _log_model_device(model)
    except ImportError as import_err:
        logger.error(f"Import error loading Dia model: {import_err}")
        logger.error("This may indicate missing dependencies")
        raise
    except FileNotFoundError as file_err:
        logger.error(f"File not found error loading Dia model: {file_err}")
        logger.error("Model path may be incorrect or inaccessible")
        raise
    except Exception as e:
        logger.error(f"Error loading Dia model: {e}", exc_info=True)
        logger.error(f"Error type: {type(e).__name__}")
        logger.error("This may indicate incompatible versions or missing CUDA support")
        raise

    # Cache only once the load (and its reporting) has fully succeeded.
    _model = model
    return _model


def _dummy_speech(text: str, language: str) -> str:
    """Synthesize ``text`` with the dummy TTS engine and return its result."""
    # Imported lazily, matching the other engine imports in this module.
    from utils.tts_base import DummyTTSEngine

    return DummyTTSEngine(language).generate_speech(text)


def generate_speech(text: str, language: str = "zh") -> str:
    """Public interface for TTS generation using the Dia model.

    This is a legacy function maintained for backward compatibility.
    New code should use the factory pattern implementation directly.

    If Dia is unavailable, or the Dia engine raises while generating, the
    dummy TTS engine is used instead.

    Args:
        text (str): Input text to synthesize.
        language (str): Language code. It is passed to the engine constructor.
            The previous documentation noted that the Dia model itself does not
            use it and that it is kept for API compatibility.

    Returns:
        str: Path to the generated audio file, as returned by the selected
        engine's ``generate_speech``.
    """
    logger.info(f"Legacy Dia generate_speech called with text length: {len(text)}")

    # Check if Dia is available
    if not DIA_AVAILABLE:
        logger.warning("Dia is not available, falling back to dummy TTS engine")
        return _dummy_speech(text, language)

    # Use the new implementation via factory pattern
    try:
        # Import here to avoid circular imports
        from utils.tts_engines import DiaTTSEngine

        return DiaTTSEngine(language).generate_speech(text)
    except ModuleNotFoundError as e:
        logger.error(f"Module not found error in Dia generate_speech: {str(e)}")
        if "dac" in str(e):
            logger.warning("Dia TTS engine failed due to missing 'dac' module, falling back to dummy TTS")
    except Exception as e:
        logger.error(f"Error in legacy Dia generate_speech: {str(e)}", exc_info=True)

    # Reached only when the Dia path raised: fall back to dummy TTS
    return _dummy_speech(text, language)