"""Audio transcription and text-to-speech functions"""
import os
import asyncio
import tempfile
import base64
import traceback
from typing import Optional
import soundfile as sf
import torch
import numpy as np
from logger import logger
from client import MCP_AVAILABLE, call_agent, get_mcp_session, get_cached_mcp_tools
import config
from models import TTS_AVAILABLE, SNAC_AVAILABLE, WHISPER_AVAILABLE, initialize_tts_model, initialize_whisper_model

# Maya1 constants (from maya1 docs)
CODE_START_TOKEN_ID = 128257
CODE_END_TOKEN_ID = 128258
CODE_TOKEN_OFFSET = 128266
SNAC_MIN_ID = 128266
SNAC_MAX_ID = 156937
SOH_ID = 128259
EOH_ID = 128260
SOA_ID = 128261
TEXT_EOT_ID = 128009
AUDIO_SAMPLE_RATE = 24000

# Sample rate the Whisper feature extractor is fed with.
_WHISPER_SAMPLE_RATE = 16000
# Number of leading decoded SNAC samples dropped as decoder warm-up.
_SNAC_WARMUP_SAMPLES = 2048

# Default voice description for Maya1
DEFAULT_VOICE_DESCRIPTION = "Realistic male voice in the 30s age with a american accent. Normal pitch, warm timbre, conversational pacing, neutral tone delivery at medium intensity, podcast domain, narrator role, neutral delivery"

import spaces

try:
    import nest_asyncio
except ImportError:
    nest_asyncio = None


# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------

def _run_coroutine_sync(coro):
    """Run *coro* to completion from synchronous code and return its result.

    Uses the thread's current event loop. If the thread has none (or only a
    closed one), a new loop is created and installed for the thread.

    If the loop is already running, this sync code was called from inside a
    coroutine. In that case the loop is made re-entrant with
    ``nest_asyncio.apply`` before ``run_until_complete``. The ``nest_asyncio``
    package exposes ``apply`` as its entry point; it has no module-level
    ``run``.

    Raises:
        RuntimeError: if the loop is running and ``nest_asyncio`` is not
            installed.
        Any exception raised by *coro* itself.
    """
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        # e.g. a worker thread that never had an event loop.
        loop = None
    if loop is None or loop.is_closed():
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    if loop.is_running():
        if nest_asyncio is None:
            coro.close()  # avoid a "coroutine was never awaited" warning
            raise RuntimeError("nest_asyncio not available for nested async execution")
        nest_asyncio.apply(loop)
    return loop.run_until_complete(coro)


def _remove_temp_file(path):
    """Best-effort deletion of a temporary file; ``None`` is a no-op."""
    if path is None:
        return
    try:
        os.unlink(path)
    except FileNotFoundError:
        pass
    except OSError as e:
        logger.warning(f"[ASR] Could not remove temporary audio file {path}: {e}")


# ---------------------------------------------------------------------------
# Speech-to-text
# ---------------------------------------------------------------------------

async def transcribe_audio_gemini(audio_path: str) -> str:
    """Transcribe audio using Gemini MCP transcribe_audio tool.

    If the MCP server exposes a tool named ``transcribe_audio``, it is called
    with the absolute file path. Otherwise the file is sent to ``call_agent``
    with a transcription prompt.

    Returns:
        The stripped transcript, or "" when MCP is unavailable, no session
        exists, the tool returns nothing usable, or any error occurs (errors
        are logged, not raised).
    """
    if not MCP_AVAILABLE:
        return ""
    try:
        session = await get_mcp_session()
        if session is None:
            logger.warning("MCP session not available for transcription")
            return ""

        tools = await get_cached_mcp_tools() or []
        transcribe_tool = next((tool for tool in tools if tool.name == "transcribe_audio"), None)
        audio_path_abs = os.path.abspath(audio_path)

        if transcribe_tool is None:
            logger.warning("transcribe_audio MCP tool not found, falling back to generate_content")
            system_prompt = "You are a professional transcription service. Provide accurate, well-formatted transcripts."
            user_prompt = "Please transcribe this audio file. Include speaker identification if multiple speakers are present, and format it with proper punctuation and paragraphs, remove mumble, ignore non-verbal noises."
            result = await call_agent(
                user_prompt=user_prompt,
                system_prompt=system_prompt,
                files=[{"path": audio_path_abs}],
                model=config.GEMINI_MODEL_LITE,
                temperature=0.2,
            )
            return (result or "").strip()

        logger.info(f"Found MCP transcribe_audio tool: {transcribe_tool.name}")
        result = await session.call_tool(
            transcribe_tool.name,
            arguments={"audio_path": audio_path_abs},
        )
        # Return the first non-empty text content item.
        for item in getattr(result, "content", None) or []:
            item_text = getattr(item, "text", None)
            if item_text:
                transcribed_text = item_text.strip()
                if transcribed_text:
                    logger.info(f"✅ Transcribed via MCP transcribe_audio tool: {transcribed_text[:50]}...")
                    return transcribed_text

        logger.warning("MCP transcribe_audio returned empty result")
        return ""
    except Exception as e:
        logger.error(f"Gemini transcription error: {e}")
        return ""


def _ensure_whisper_model() -> bool:
    """Lazily initialise ``config.global_whisper_model``; return True once it is loaded."""
    if config.global_whisper_model is not None:
        return True
    logger.info("[ASR] Whisper model not loaded, initializing now (on-demand)...")
    try:
        initialize_whisper_model()
    except Exception as e:
        logger.error(f"[ASR] Error initializing Whisper model: {e}")
        logger.error(f"[ASR] Initialization traceback: {traceback.format_exc()}")
        return False
    if config.global_whisper_model is None:
        logger.error("[ASR] Failed to initialize Whisper model - check logs for errors")
        return False
    logger.info("[ASR] ✅ Whisper model loaded successfully on-demand!")
    return True


def _load_audio_for_whisper(audio_path: str):
    """Load *audio_path* as mono float32 audio at Whisper's 16 kHz rate.

    Resampling uses ``scipy.signal.resample`` when scipy is importable.
    Otherwise it falls back to linear interpolation.

    Returns:
        ``(audio_array, sample_rate)``: a 1-D float32 numpy array and
        ``_WHISPER_SAMPLE_RATE``.
    """
    logger.info(f"[ASR] Loading audio with soundfile: {audio_path}")
    audio_data, sample_rate = sf.read(audio_path, dtype='float32')
    logger.info(f"[ASR] Loaded audio with soundfile: shape={audio_data.shape}, sample_rate={sample_rate}, dtype={audio_data.dtype}")

    # soundfile returns (frames,) for mono and (frames, channels) otherwise.
    if audio_data.ndim > 1:
        if audio_data.shape[1] > 1:
            logger.info(f"[ASR] Converting {audio_data.shape[1]}-channel audio to mono")
        audio_data = audio_data.mean(axis=1)

    if sample_rate != _WHISPER_SAMPLE_RATE:
        logger.info(f"[ASR] Resampling from {sample_rate}Hz to 16000Hz")
        num_samples = int(len(audio_data) * _WHISPER_SAMPLE_RATE / sample_rate)
        try:
            from scipy import signal  # optional dependency
        except ImportError:
            logger.info("[ASR] scipy not available, using simple linear interpolation for resampling")
            positions = np.linspace(0, len(audio_data) - 1, num_samples)
            audio_data = np.interp(positions, np.arange(len(audio_data)), audio_data)
            method = "simple interpolation"
        else:
            audio_data = signal.resample(audio_data, num_samples)
            method = "scipy"
        audio_data = np.asarray(audio_data, dtype=np.float32)
        sample_rate = _WHISPER_SAMPLE_RATE
        logger.info(f"[ASR] Resampled using {method}: new shape={audio_data.shape}")

    logger.info(f"[ASR] Audio ready: shape={audio_data.shape}, sample_rate={sample_rate}")
    return audio_data, sample_rate


def _whisper_generate(model, input_features):
    """Run Whisper ``generate``, retrying once with default arguments on failure.

    Returns:
        The generated token ids, or None if both attempts raised.
    """
    with torch.no_grad():
        try:
            generated_ids = model.generate(
                input_features,
                max_length=448,  # Whisper default max length
                num_beams=5,
                language=None,  # Auto-detect language
                task="transcribe",
                return_timestamps=False,
            )
        except Exception as gen_error:
            logger.error(f"[ASR] Error in model.generate(): {gen_error}")
            logger.error(f"[ASR] Generate traceback: {traceback.format_exc()}")
            logger.info("[ASR] Retrying with minimal parameters...")
            try:
                generated_ids = model.generate(input_features)
            except Exception as retry_error:
                logger.error(f"[ASR] Retry also failed: {retry_error}")
                return None
            logger.info(f"[ASR] Retry successful, generated IDs shape: {generated_ids.shape}")
            return generated_ids

    logger.info(f"[ASR] Generated IDs shape: {generated_ids.shape}, dtype: {generated_ids.dtype}")
    logger.info(f"[ASR] Generated IDs sample: {generated_ids[0][:20] if len(generated_ids) > 0 else 'empty'}")
    return generated_ids


@spaces.GPU(max_duration=60)
def transcribe_audio_whisper(audio_path: str) -> str:
    """Transcribe audio using Whisper model from Hugging Face.

    Steps:
      1. Lazily load ``config.global_whisper_model``, a dict with
         ``"processor"`` and ``"model"`` entries.
      2. Convert the file to mono 16 kHz.
      3. Cast the features to the model's device and dtype.
      4. Generate and decode the first sequence.

    Returns:
        The stripped transcript, or "" when Whisper is unavailable, the file
        is missing, or any step fails (errors are logged, not raised).
    """
    if not WHISPER_AVAILABLE:
        logger.warning("[ASR] Whisper not available for transcription")
        return ""
    try:
        logger.info(f"[ASR] Starting Whisper transcription for: {audio_path}")
        if not _ensure_whisper_model():
            return ""

        processor = config.global_whisper_model["processor"]
        model = config.global_whisper_model["model"]

        logger.info("[ASR] Loading audio file...")
        if not os.path.exists(audio_path):
            logger.error(f"[ASR] Audio file not found: {audio_path}")
            return ""

        try:
            audio_array, sample_rate = _load_audio_for_whisper(audio_path)

            logger.info("[ASR] Processing audio with Whisper processor...")
            logger.info(f"[ASR] Audio array shape: {audio_array.shape}, dtype: {audio_array.dtype}")
            inputs = processor(audio_array, sampling_rate=sample_rate, return_tensors="pt")
            logger.info(f"[ASR] Processor inputs: {list(inputs.keys())}")
            if "input_features" not in inputs:
                logger.error(f"[ASR] Missing input_features in processor output. Keys: {list(inputs.keys())}")
                return ""

            # Features must be on the model's device and match its weight
            # dtype (e.g. float16), otherwise generate() fails.
            first_param = next(model.parameters())
            logger.info(f"[ASR] Model device: {first_param.device}")
            input_features = inputs["input_features"].to(device=first_param.device, dtype=first_param.dtype)
            logger.info(f"[ASR] Input features shape: {input_features.shape}, dtype: {input_features.dtype}")

            logger.info("[ASR] Running Whisper model.generate()...")
            generated_ids = _whisper_generate(model, input_features)
            if generated_ids is None:
                return ""

            logger.info("[ASR] Decoding transcription...")
            transcribed_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
            if transcribed_text:
                logger.info(f"[ASR] ✅ Transcription successful: {transcribed_text[:100]}...")
                logger.info(f"[ASR] Transcription length: {len(transcribed_text)} characters")
            else:
                logger.warning("[ASR] Whisper returned empty transcription")
                logger.warning(f"[ASR] Generated IDs: {generated_ids}")
                logger.warning(f"[ASR] Decoded (before strip): {processor.batch_decode(generated_ids, skip_special_tokens=False)[0]}")
            return transcribed_text
        except Exception as audio_error:
            logger.error(f"[ASR] Error processing audio file: {audio_error}")
            logger.error(f"[ASR] Audio processing traceback: {traceback.format_exc()}")
            return ""
    except Exception as e:
        logger.error(f"[ASR] Whisper transcription error: {e}")
        logger.error(f"[ASR] Full traceback: {traceback.format_exc()}")
        return ""


def transcribe_audio(audio):
    """Transcribe audio to text using Whisper (primary) or Gemini MCP (fallback).

    Args:
        audio: a file path, or a ``(sample_rate, samples)`` tuple. A tuple is
            written to a temporary WAV file, which is always deleted before
            returning. ``None`` yields "".

    Returns:
        The transcript, or "" if every method fails.
    """
    if audio is None:
        logger.warning("[ASR] No audio provided")
        return ""

    temp_path = None
    try:
        if isinstance(audio, tuple):
            sample_rate, audio_data = audio
            logger.info(f"[ASR] Processing audio tuple: sample_rate={sample_rate}, data_shape={audio_data.shape if hasattr(audio_data, 'shape') else 'unknown'}")
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
                temp_path = tmp_file.name
            # Written after the handle is closed so this also works on Windows.
            sf.write(temp_path, audio_data, samplerate=sample_rate)
            audio_path = temp_path
            logger.info(f"[ASR] Created temporary audio file: {audio_path}")
        else:
            audio_path = audio

        logger.info("[ASR] Attempting transcription with Whisper (primary method)...")
        if WHISPER_AVAILABLE:
            try:
                transcribed = transcribe_audio_whisper(audio_path)
            except Exception as e:
                logger.error(f"[ASR] Whisper transcription failed: {e}, trying fallback...")
            else:
                if transcribed:
                    logger.info(f"[ASR] ✅ Successfully transcribed via Whisper: {transcribed[:50]}...")
                    return transcribed
                logger.warning("[ASR] Whisper transcription returned empty, trying fallback...")
        else:
            logger.warning("[ASR] Whisper not available, trying Gemini fallback...")

        if MCP_AVAILABLE:
            logger.info("[ASR] Attempting transcription with Gemini MCP (fallback)...")
            try:
                transcribed = _run_coroutine_sync(transcribe_audio_gemini(audio_path))
            except Exception as e:
                logger.error(f"[ASR] Gemini MCP transcription error: {e}")
            else:
                if transcribed:
                    logger.info(f"[ASR] Transcribed via Gemini MCP (fallback): {transcribed[:50]}...")
                    return transcribed

        logger.warning("[ASR] All transcription methods failed")
        return ""
    except Exception as e:
        logger.error(f"[ASR] Transcription error: {e}")
        logger.debug(f"[ASR] Full traceback: {traceback.format_exc()}")
        return ""
    finally:
        _remove_temp_file(temp_path)


# ---------------------------------------------------------------------------
# Text-to-speech
# ---------------------------------------------------------------------------

def _find_tts_tool(tools):
    """Return the MCP ``text_to_speech`` tool, else the first TTS-like tool, else None."""
    for tool in tools:
        if tool.name == "text_to_speech":
            logger.info(f"Found MCP text_to_speech tool: {tool.name}")
            return tool
    for tool in tools:
        tool_name_lower = tool.name.lower()
        if "tts" in tool_name_lower or "speech" in tool_name_lower or "synthesize" in tool_name_lower:
            logger.info(f"Found MCP TTS tool (fallback): {tool.name}")
            return tool
    return None


def _write_temp_audio(data) -> str:
    """Write an MCP audio payload to a temporary ``.wav`` file and return its path.

    MCP content blocks carry binary payloads as base64 text, so ``str`` input
    is decoded first. ``bytes`` input is written unchanged.
    """
    payload = base64.b64decode(data) if isinstance(data, str) else data
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
        tmp_file.write(payload)
        return tmp_file.name


async def generate_speech_mcp(text: str) -> Optional[str]:
    """Generate speech using MCP text_to_speech tool (fallback path).

    Returns:
        An audio file path, or None. None is returned when:
          - MCP is unavailable or there is no session;
          - no TTS-like tool exists;
          - the tool answers ``USE_LOCAL_TTS`` (client-side TTS should be used);
          - nothing usable comes back;
          - an error occurs (logged as a warning).
    """
    if not MCP_AVAILABLE:
        return None
    try:
        session = await get_mcp_session()
        if session is None:
            logger.warning("MCP session not available for TTS")
            return None

        tts_tool = _find_tts_tool(await get_cached_mcp_tools() or [])
        if tts_tool is None:
            return None

        result = await session.call_tool(
            tts_tool.name,
            arguments={"text": text, "language": "en"},
        )
        for item in getattr(result, "content", None) or []:
            if hasattr(item, "text"):
                text_result = (item.text or "").strip()
                if text_result == "USE_LOCAL_TTS":
                    logger.info("MCP TTS tool indicates client-side TTS should be used")
                    return None  # Return None to trigger client-side TTS
                if text_result and os.path.exists(text_result):
                    return text_result
            elif getattr(item, "data", None):
                return _write_temp_audio(item.data)
        return None
    except Exception as e:
        logger.warning(f"MCP TTS error: {e}")
        return None


def _generate_speech_via_mcp(text: str):
    """Helper to generate speech via MCP in a synchronous context.

    Returns:
        The audio path produced by ``generate_speech_mcp``, or None.
    """
    if not MCP_AVAILABLE:
        return None
    try:
        audio_path = _run_coroutine_sync(generate_speech_mcp(text))
    except Exception as e:
        logger.warning(f"MCP TTS error (sync wrapper): {e}")
        return None
    if audio_path:
        logger.info("Generated speech via MCP")
    return audio_path


def build_maya1_prompt(tokenizer, description: str, text: str) -> str:
    """Build formatted prompt for Maya1.

    The prompt is the concatenation of:
        SOH + BOS + '<description="{description}"> {text}' + EOT + EOH + SOA + SOS

    The special markers are obtained by decoding their token ids. The
    ``<description="...">`` tag carries the voice description, per the Maya1
    model card. Without it, the ``description`` argument has no effect.
    """
    soh_token = tokenizer.decode([SOH_ID])
    eoh_token = tokenizer.decode([EOH_ID])
    soa_token = tokenizer.decode([SOA_ID])
    sos_token = tokenizer.decode([CODE_START_TOKEN_ID])
    eot_token = tokenizer.decode([TEXT_EOT_ID])
    bos_token = tokenizer.bos_token or ""

    formatted_text = f'<description="{description}"> {text}'

    return (
        soh_token + bos_token + formatted_text + eot_token +
        eoh_token + soa_token + sos_token
    )


def unpack_snac_from_7(snac_tokens: list) -> list:
    """Unpack 7-token SNAC frames to 3 hierarchical levels.

    A trailing ``CODE_END_TOKEN_ID`` and any incomplete final frame are
    discarded. Each token is mapped to ``(token - CODE_TOKEN_OFFSET) % 4096``.
    Within each frame, the slots are routed as follows:

      - slot 0 goes to level 1;
      - slots 1 and 4 go to level 2;
      - slots 2, 3, 5 and 6 go to level 3.

    Returns:
        ``[level1, level2, level3]``, holding 1, 2 and 4 codes per frame
        respectively.
    """
    if snac_tokens and snac_tokens[-1] == CODE_END_TOKEN_ID:
        snac_tokens = snac_tokens[:-1]
    frames = len(snac_tokens) // 7
    if frames == 0:
        return [[], [], []]

    codes = [(token - CODE_TOKEN_OFFSET) % 4096 for token in snac_tokens[:frames * 7]]
    l1, l2, l3 = [], [], []
    for start in range(0, frames * 7, 7):
        s = codes[start:start + 7]
        l1.append(s[0])
        l2.extend((s[1], s[4]))
        l3.extend((s[2], s[3], s[5], s[6]))
    return [l1, l2, l3]


def _extract_snac_tokens(generated_ids: list) -> list:
    """Return the SNAC code tokens generated before the first ``CODE_END_TOKEN_ID``."""
    if CODE_END_TOKEN_ID in generated_ids:
        generated_ids = generated_ids[:generated_ids.index(CODE_END_TOKEN_ID)]
    return [t for t in generated_ids if SNAC_MIN_ID <= t <= SNAC_MAX_ID]


def _decode_snac_audio(snac_model, snac_tokens: list) -> np.ndarray:
    """Decode SNAC code tokens to a float waveform with the warm-up prefix trimmed."""
    levels = unpack_snac_from_7(snac_tokens)
    # Put the codes on whatever device the SNAC model actually lives on.
    device = next(snac_model.parameters()).device
    codes_tensor = [torch.tensor(level, dtype=torch.long, device=device).unsqueeze(0) for level in levels]
    with torch.inference_mode():
        z_q = snac_model.quantizer.from_codes(codes_tensor)
        audio = snac_model.decoder(z_q)[0, 0].float().cpu().numpy()
    # Trim warmup
    if len(audio) > _SNAC_WARMUP_SAMPLES:
        audio = audio[_SNAC_WARMUP_SAMPLES:]
    return audio


def _generate_speech_with_gpu(text: str, description: str = None):
    """Internal GPU-decorated function for TTS generation when TTS is available.

    Expects ``config.global_tts_model`` to be a dict with ``"model"``,
    ``"tokenizer"`` and ``"snac_model"``. It is initialised on first use.

    Returns:
        Path of a 16-bit WAV at ``AUDIO_SAMPLE_RATE``, or None on failure.
    """
    if config.global_tts_model is None:
        logger.info("[TTS] TTS model not loaded, initializing...")
        initialize_tts_model()

    if config.global_tts_model is None:
        logger.error("[TTS] TTS model not available. Please check dependencies.")
        return None

    # Check if it's the new Maya1 format (dictionary) or old format
    if not isinstance(config.global_tts_model, dict):
        logger.error("[TTS] TTS model format is incorrect. Expected dictionary with model, tokenizer, snac_model.")
        return None

    try:
        model = config.global_tts_model["model"]
        tokenizer = config.global_tts_model["tokenizer"]
        snac_model = config.global_tts_model["snac_model"]

        if description is None:
            description = DEFAULT_VOICE_DESCRIPTION

        logger.info("[TTS] Running Maya1 TTS generation...")
        prompt = build_maya1_prompt(tokenizer, description, text)
        inputs = tokenizer(prompt, return_tensors="pt")
        # Send inputs to the model's own device instead of assuming CUDA.
        model_device = next(model.parameters()).device
        inputs = {k: v.to(model_device) for k, v in inputs.items()}

        with torch.inference_mode():
            outputs = model.generate(
                **inputs,
                max_new_tokens=1500,
                min_new_tokens=28,
                temperature=0.4,
                top_p=0.9,
                repetition_penalty=1.1,
                do_sample=True,
                eos_token_id=CODE_END_TOKEN_ID,
                pad_token_id=tokenizer.pad_token_id,
            )

        # Only the newly generated part (after the prompt) carries audio codes.
        prompt_length = inputs['input_ids'].shape[1]
        snac_tokens = _extract_snac_tokens(outputs[0, prompt_length:].tolist())
        if len(snac_tokens) < 7:
            logger.error(f"[TTS] Not enough tokens generated ({len(snac_tokens)}). Try different text or increase max_tokens.")
            return None

        audio = _decode_snac_audio(snac_model, snac_tokens)

        # Clip before scaling: values outside [-1, 1] would otherwise wrap
        # around in int16 and produce loud clicks.
        audio_int16 = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)

        with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
            tmp_path = tmp_file.name
        sf.write(tmp_path, audio_int16, AUDIO_SAMPLE_RATE)

        duration = len(audio) / AUDIO_SAMPLE_RATE
        logger.info(f"[TTS] ✅ Speech generated successfully: {tmp_path} ({duration:.2f}s)")
        return tmp_path
    except Exception as e:
        logger.error(f"[TTS] TTS error (local maya1): {e}")
        logger.debug(f"[TTS] Full traceback: {traceback.format_exc()}")
        return None


@spaces.GPU(max_duration=120)
def _generate_speech_gpu_wrapper(text: str):
    """GPU wrapper for TTS generation - only called when TTS is available."""
    return _generate_speech_with_gpu(text)


def generate_speech(text: str):
    """Generate speech from text using local maya1 TTS model (with MCP fallback).

    The primary path uses the local TTS model (maya-research/maya1).
    MCP-based TTS is only used as a last-resort fallback if the local model
    is unavailable or fails.

    This function checks TTS availability before attempting GPU allocation.

    Returns:
        An audio file path, or None when text is empty or every path fails.
    """
    if not text or not text.strip():
        logger.warning("[TTS] Empty text provided")
        return None

    logger.info(f"[TTS] Generating speech for text: {text[:50]}...")

    # Avoid GPU allocation entirely when Maya1's SNAC dependency is missing.
    if not SNAC_AVAILABLE:
        logger.warning("[TTS] SNAC library not installed (required for Maya1). Trying MCP fallback...")
        # MCP-based TTS does not require a GPU.
        audio_path = _generate_speech_via_mcp(text)
        if audio_path:
            logger.info(f"[TTS] ✅ Generated via MCP fallback: {audio_path}")
            return audio_path
        logger.error("[TTS] ❌ SNAC library not installed and MCP fallback failed. Please install: pip install snac")
        return None

    try:
        audio_path = _generate_speech_gpu_wrapper(text)
    except Exception as e:
        logger.error(f"[TTS] GPU TTS generation error: {e}")
        logger.debug(f"[TTS] Full traceback: {traceback.format_exc()}")
        logger.info("[TTS] Attempting MCP fallback after error...")
        return _generate_speech_via_mcp(text)

    if audio_path:
        return audio_path
    logger.warning("[TTS] Local TTS generation failed, trying MCP fallback...")
    return _generate_speech_via_mcp(text)