teachingAssistant / utils /tts_cosyvoice2.py
Michael Hu
fix circular dependency
aaa0814
import logging
import numpy as np
import soundfile as sf
from typing import Optional, Generator, Tuple
from utils.tts_base import TTSBase
# Configure logging
logger = logging.getLogger(__name__)

# Flag to track CosyVoice2 availability. It starts out False and is flipped to
# True only when every optional dependency below imports cleanly.
COSYVOICE2_AVAILABLE = False

# Sample rate used when writing generated audio to disk and when yielding
# streamed audio segments.
DEFAULT_SAMPLE_RATE = 24000

# Try to import CosyVoice2 dependencies
try:
    import torch  # noqa: F401  (imported only to probe availability)

    # NOTE(review): the import path and the model API are assumed to mirror
    # the Dia engine; confirm against the CosyVoice2 package that is actually
    # installed.
    from cosyvoice2.model import CosyVoice2  # noqa: F401
except ImportError as e:
    # ModuleNotFoundError is a subclass of ImportError, so a single handler
    # covers both cases; a separate ModuleNotFoundError clause placed after
    # this one could never run. The flag keeps its initial value of False.
    logger.warning(f"CosyVoice2 TTS engine is not available: {e}")
else:
    COSYVOICE2_AVAILABLE = True
    logger.info("CosyVoice2 TTS engine is available")
def _get_model():
    """Load and return a CosyVoice2 model instance.

    A new model is built with ``CosyVoice2.from_pretrained()`` on every call;
    nothing is cached here, so callers should hold on to the returned object.
    Failures are logged and reported as ``None`` rather than raised.

    Returns:
        CosyVoice2 or None: The CosyVoice2 model or None if not available
    """
    if not COSYVOICE2_AVAILABLE:
        logger.warning("CosyVoice2 TTS engine is not available")
        return None

    try:
        # Imported locally so the name is resolved at call time.
        from cosyvoice2.model import CosyVoice2

        # Initialize the model
        model = CosyVoice2.from_pretrained()
    except ImportError as e:
        logger.error(f"Failed to import CosyVoice2 dependencies: {e}")
        return None
    except FileNotFoundError as e:
        logger.error(f"Failed to load CosyVoice2 model files: {e}")
        return None
    except Exception as e:
        # Deliberately broad: any other initialization failure is reported to
        # the caller as "no model" instead of propagating.
        logger.error(f"Failed to initialize CosyVoice2 model: {e}")
        return None

    logger.info("CosyVoice2 model successfully loaded")
    return model
class CosyVoice2TTS(TTSBase):
    """CosyVoice2 TTS engine implementation

    This engine uses the CosyVoice2 model for TTS generation. The model is
    not loaded at construction time; it is fetched through ``_get_model()``
    the first time speech is requested and then kept on the instance.
    """

    def __init__(self, lang_code: str = 'z'):
        """Initialize the CosyVoice2 TTS engine

        Args:
            lang_code (str): Language code for the engine
        """
        super().__init__(lang_code)
        # Populated lazily by _ensure_model().
        self.model = None

    def _ensure_model(self) -> bool:
        """Ensure the model is loaded

        Returns:
            bool: True if model is available, False otherwise
        """
        if self.model is None:
            self.model = _get_model()
        return self.model is not None

    def _model_ready(self) -> bool:
        """Check that the engine is importable and the model is loaded.

        Logs the reason at error level when the engine cannot be used.

        Returns:
            bool: True when ``self.model`` can be used for generation
        """
        # Check if CosyVoice2 is available
        if not COSYVOICE2_AVAILABLE:
            logger.error("CosyVoice2 TTS engine is not available")
            return False

        # Ensure model is loaded
        if not self._ensure_model():
            logger.error("Failed to load CosyVoice2 model")
            return False

        return True

    def _run_inference(self, text: str):
        """Run the model on ``text`` with the engine's fixed sampling settings.

        Shared by the file-based and streaming entry points so that both use
        exactly the same generation parameters. Exceptions raised by the
        model propagate to the caller.

        Args:
            text (str): Input text to synthesize

        Returns:
            The value returned by ``self.model.generate`` (``None`` when the
            model produced no audio; that case is logged here).
        """
        import torch

        with torch.inference_mode():
            # Assuming CosyVoice2 has a similar API to Dia
            audio = self.model.generate(
                text,
                max_tokens=None,
                cfg_scale=3.0,
                temperature=1.3,
                top_p=0.95,
                use_torch_compile=False,
                verbose=False,
            )

        if audio is None:
            logger.error("CosyVoice2 model returned None for audio output")
        else:
            logger.info(f"Successfully generated audio with CosyVoice2 (length: {len(audio)})")
        return audio

    def generate_speech(self, text: str, voice: str = 'default', speed: float = 1.0) -> Optional[str]:
        """Generate speech using CosyVoice2 TTS engine

        Args:
            text (str): Input text to synthesize
            voice (str): Voice ID (accepted for interface compatibility;
                currently ignored by this engine)
            speed (float): Speech speed multiplier (accepted for interface
                compatibility; currently ignored by this engine)

        Returns:
            Optional[str]: Path to the generated audio file or None if generation fails
        """
        logger.info(f"Generating speech with CosyVoice2 for text length: {len(text)}")

        if not self._model_ready():
            return None

        try:
            # Generate unique output path
            output_path = self._generate_output_path(prefix="cosyvoice2")

            # Generate audio
            output_audio_np = self._run_inference(text)
            if output_audio_np is None:
                return None

            sf.write(output_path, output_audio_np, DEFAULT_SAMPLE_RATE)
            logger.info(f"CosyVoice2 audio generation complete: {output_path}")
            return output_path
        except Exception as e:
            # Best-effort contract: failures are logged with a traceback and
            # reported to the caller as None.
            logger.error(f"Error generating speech with CosyVoice2: {e}", exc_info=True)
            return None

    def generate_speech_stream(self, text: str, voice: str = 'default', speed: float = 1.0) -> Generator[Tuple[int, np.ndarray], None, None]:
        """Generate speech stream using CosyVoice2 TTS engine

        The whole utterance is generated in one model call, so at most one
        segment is yielded; nothing is yielded when generation fails.

        Args:
            text (str): Input text to synthesize
            voice (str): Voice ID (accepted for interface compatibility;
                currently ignored by this engine)
            speed (float): Speech speed multiplier (accepted for interface
                compatibility; currently ignored by this engine)

        Yields:
            tuple: (sample_rate, audio_data) pairs for each segment
        """
        logger.info(f"Generating speech stream with CosyVoice2 for text length: {len(text)}")

        if not self._model_ready():
            return

        try:
            # Generate audio
            output_audio_np = self._run_inference(text)
            if output_audio_np is None:
                return

            yield DEFAULT_SAMPLE_RATE, output_audio_np
        except Exception as e:
            # Best-effort contract: failures are logged with a traceback and
            # the stream simply ends.
            logger.error(f"Error generating speech stream with CosyVoice2: {e}", exc_info=True)
            return