Spaces:
Running
Running
File size: 5,296 Bytes
030c851 cb90410 030c851 cb90410 030c851 cb90410 030c851 91223c9 030c851 91223c9 2d176f4 91223c9 030c851 91223c9 030c851 c549dab 030c851 3ed3b5a 030c851 3ed3b5a 4a9bb1a cb90410 4a9bb1a cb90410 4a9bb1a cb90410 3ed3b5a cb90410 3ed3b5a 4a9bb1a 3ed3b5a cb90410 3ed3b5a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
import os
import time
import logging
import numpy as np
import soundfile as sf
from pathlib import Path
from typing import Optional
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Flag to track Dia availability. It is switched to True only after both
# torch and dia.model have been imported successfully.
DIA_AVAILABLE = False

# Probe the optional dependencies one at a time so that every failure is
# reported against the package that actually failed to import.
try:
    import torch
except ImportError:
    logger.warning("Torch not available, Dia TTS engine cannot be used")
else:
    # Try to import Dia, which will try to import dac
    try:
        from dia.model import Dia
    except ImportError as e:
        # Catch ImportError rather than only ModuleNotFoundError: errors such
        # as "cannot import name ..." are plain ImportErrors and must not be
        # misreported as a missing torch installation.
        if "dac" in str(e):
            logger.warning("Dia TTS engine is not available due to missing 'dac' module")
        else:
            logger.warning(f"Dia TTS engine is not available: {str(e)}")
    else:
        DIA_AVAILABLE = True
        logger.info("Dia TTS engine is available")
# Constants
# NOTE(review): presumably the output sample rate in Hz; it is not referenced
# by the code visible in this module, so confirm against its users.
DEFAULT_SAMPLE_RATE: int = 44100
# Model identifier handed to Dia.from_pretrained() when the model is loaded.
DEFAULT_MODEL_NAME: str = "nari-labs/Dia-1.6B"

# Global model instance (lazy loaded): remains None until the first
# successful load.
_model = None
def _describe_model_device(model) -> str:
    """Return a printable device for *model*, or a placeholder when unknown.

    This is used for diagnostics only, so it never raises for a model that
    lacks a ``parameters`` method or whose parameter iterator is empty.
    """
    unavailable = "Device information not available for Dia model"
    # PyTorch modules expose parameters(); the Dia wrapper might not.
    parameters = getattr(model, "parameters", None)
    if not callable(parameters):
        return unavailable
    # next(..., None) instead of a bare next(): an empty iterator must not
    # surface as StopIteration from a purely informational log line.
    first_param = next(iter(parameters()), None)
    device = getattr(first_param, "device", None)
    return unavailable if device is None else str(device)


def _get_model():
    """Lazy-load the Dia model to avoid loading it until needed.

    The first successful call loads ``DEFAULT_MODEL_NAME`` through
    ``Dia.from_pretrained`` (with ``compute_dtype="float16"``) and stores the
    result in the module-level ``_model``; later calls return that instance.

    Returns:
        The loaded Dia model instance.

    Raises:
        ImportError: If ``DIA_AVAILABLE`` is false, or if loading the model
            raises an ImportError.
        FileNotFoundError: If loading the model raises it.
        Exception: Any other error raised while loading is logged and
            re-raised unchanged.
    """
    global _model

    # Check if Dia is available before attempting to load
    if not DIA_AVAILABLE:
        logger.warning("Dia is not available, cannot load model")
        raise ImportError("Dia module is not available")

    if _model is not None:
        return _model

    logger.info("Loading Dia model...")
    try:
        # Log the runtime environment to make load failures easier to diagnose.
        logger.info(f"PyTorch version: {torch.__version__}")
        cuda_available = torch.cuda.is_available()
        logger.info(f"CUDA available: {cuda_available}")
        if cuda_available:
            logger.info(f"CUDA version: {torch.version.cuda}")
            logger.info(f"GPU device: {torch.cuda.get_device_name(0)}")

        logger.info(f"Attempting to load model from: {DEFAULT_MODEL_NAME}")
        logger.info("Initializing Dia model...")
        loaded = Dia.from_pretrained(DEFAULT_MODEL_NAME, compute_dtype="float16")
    except ImportError as import_err:
        logger.error(f"Import error loading Dia model: {import_err}")
        logger.error("This may indicate missing dependencies")
        raise
    except FileNotFoundError as file_err:
        logger.error(f"File not found error loading Dia model: {file_err}")
        logger.error("Model path may be incorrect or inaccessible")
        raise
    except Exception as e:
        logger.error(f"Error loading Dia model: {e}", exc_info=True)
        logger.error(f"Error type: {type(e).__name__}")
        logger.error("This may indicate incompatible versions or missing CUDA support")
        raise

    # Publish the model only after the load succeeded. The detail logging
    # below sits outside the try block so a problem there is never reported
    # as a model loading failure.
    _model = loaded
    logger.info("Dia model loaded successfully")
    logger.info(f"Model type: {type(loaded).__name__}")
    logger.info(f"Model device: {_describe_model_device(loaded)}")
    return _model
def _speak_with_dummy_engine(text: str, language: str) -> str:
    """Run *text* through the DummyTTSEngine fallback and return its result."""
    # Imported at call time, matching where the fallback import has always
    # happened in this module.
    from utils.tts_base import DummyTTSEngine

    return DummyTTSEngine(language).generate_speech(text)


def generate_speech(text: str, language: str = "zh") -> str:
    """Public interface for TTS generation using the Dia model.

    This is a legacy function maintained for backward compatibility.
    New code should use the factory pattern implementation directly.

    The call is delegated to ``DiaTTSEngine``. When Dia is unavailable, or
    when the Dia engine raises any exception, the request is served by
    ``DummyTTSEngine`` instead.

    Args:
        text (str): Input text to synthesize.
        language (str): Language code. It is forwarded to the engine
            constructors; per the original notes it is not used by the Dia
            model itself and is kept for API compatibility.

    Returns:
        str: Path to the generated audio file, as returned by the engine's
        ``generate_speech``.
    """
    logger.info(f"Legacy Dia generate_speech called with text length: {len(text)}")

    # Without Dia there is nothing to try: go straight to the fallback.
    if not DIA_AVAILABLE:
        logger.warning("Dia is not available, falling back to dummy TTS engine")
        return _speak_with_dummy_engine(text, language)

    try:
        # Import here to avoid circular imports
        from utils.tts_engines import DiaTTSEngine

        return DiaTTSEngine(language).generate_speech(text)
    except ModuleNotFoundError as missing:
        logger.error(f"Module not found error in Dia generate_speech: {str(missing)}")
        if "dac" in str(missing):
            logger.warning("Dia TTS engine failed due to missing 'dac' module, falling back to dummy TTS")
        # Any missing module, 'dac' or not, ends in the dummy engine.
        return _speak_with_dummy_engine(text, language)
    except Exception as failure:
        logger.error(f"Error in legacy Dia generate_speech: {str(failure)}", exc_info=True)
        # Best-effort contract: degrade to the dummy engine rather than fail.
        return _speak_with_dummy_engine(text, language)