Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| AI Avatar Chat - HF Spaces Optimized Version | |
| With robust import handling and graceful fallbacks | |
| """ | |
| import os | |
# STORAGE OPTIMIZATION: Check if running on HF Spaces and disable model downloads.
# Every signal is collected first (three environment markers plus the working
# directory hint); a single truthy entry is enough to flag a Space.
_hf_space_signals = [
    os.getenv(marker)
    for marker in ("SPACE_ID", "SPACE_AUTHOR_NAME", "SPACES_BUILDKIT_VERSION")
]
_hf_space_signals.append("/home/user/app" in os.getcwd())
IS_HF_SPACE = any(_hf_space_signals)

if IS_HF_SPACE:
    # Force TTS-only mode to prevent storage limit exceeded
    os.environ.update(
        {
            "DISABLE_MODEL_DOWNLOAD": "1",
            "TTS_ONLY_MODE": "1",
            "HF_SPACE_STORAGE_OPTIMIZED": "1",
        }
    )
    for _banner_line in (
        "?? STORAGE OPTIMIZATION: Detected HF Space environment",
        "??? TTS-only mode ENABLED (video generation disabled for storage limits)",
        "?? Model auto-download DISABLED to prevent storage exceeded error",
    ):
        print(_banner_line)
| # Core imports (required) | |
| import torch | |
| import tempfile | |
| import gradio as gr | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel, HttpUrl | |
| import subprocess | |
| import json | |
| from pathlib import Path | |
| import logging | |
| import requests | |
| from urllib.parse import urlparse | |
| from PIL import Image | |
| import io | |
| from typing import Optional | |
| import asyncio | |
| import time | |
| # Optional imports with graceful fallbacks | |
| try: | |
| import aiohttp | |
| AIOHTTP_AVAILABLE = True | |
| except ImportError: | |
| print("Warning: aiohttp not available") | |
| AIOHTTP_AVAILABLE = False | |
| try: | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
| DOTENV_AVAILABLE = True | |
| except ImportError: | |
| print("Warning: python-dotenv not available, skipping .env file loading") | |
| DOTENV_AVAILABLE = False | |
| def load_dotenv(): | |
| pass | |
| try: | |
| from hf_spaces_fix import setup_hf_spaces_environment, HFSpacesCompatible | |
| setup_hf_spaces_environment() | |
| HF_FIX_AVAILABLE = True | |
| except ImportError: | |
| print("Warning: HF Spaces fix not available") | |
| HF_FIX_AVAILABLE = False | |
| # STREAMING MODEL SUPPORT for HF Spaces | |
| try: | |
| from streaming_video_engine import streaming_engine | |
| STREAMING_ENABLED = True | |
| print("?? Streaming video engine loaded successfully") | |
| except ImportError: | |
| STREAMING_ENABLED = False | |
| print("?? Streaming engine not available, using TTS-only mode") | |
| print("? Core imports loaded successfully") | |
| print(f"?? Optional features: aiohttp={AIOHTTP_AVAILABLE}, dotenv={DOTENV_AVAILABLE}, hf_fix={HF_FIX_AVAILABLE}, streaming={STREAMING_ENABLED}") | |
| import os | |
# STORAGE OPTIMIZATION: Check if running on HF Spaces and disable model downloads.
# The working-directory hint is evaluated unconditionally, then combined with
# the environment markers.
_in_space_workdir = "/home/user/app" in os.getcwd()
_space_env_present = any(
    os.getenv(name)
    for name in ("SPACE_ID", "SPACE_AUTHOR_NAME", "SPACES_BUILDKIT_VERSION")
)
IS_HF_SPACE = _space_env_present or _in_space_workdir

if IS_HF_SPACE:
    # Force TTS-only mode to prevent storage limit exceeded
    for _flag in ("DISABLE_MODEL_DOWNLOAD", "TTS_ONLY_MODE", "HF_SPACE_STORAGE_OPTIMIZED"):
        os.environ[_flag] = "1"
    print("?? STORAGE OPTIMIZATION: Detected HF Space environment")
    print("??? TTS-only mode ENABLED (video generation disabled for storage limits)")
    print("?? Model auto-download DISABLED to prevent storage exceeded error")
| import os | |
| import torch | |
| import tempfile | |
| import gradio as gr | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel, HttpUrl | |
| import subprocess | |
| import json | |
| from pathlib import Path | |
| import logging | |
| import requests | |
| from urllib.parse import urlparse | |
| from PIL import Image | |
| import io | |
| from typing import Optional | |
| import aiohttp | |
| import asyncio | |
| # Optional dotenv import for environment variables | |
| try: | |
| from dotenv import load_dotenv | |
| load_dotenv() # Load .env file if it exists | |
| except ImportError: | |
| print("Warning: python-dotenv not available, skipping .env file loading") | |
| def load_dotenv(): | |
| pass # No-op function | |
| # CRITICAL: HF Spaces compatibility fix | |
| try: | |
| from hf_spaces_fix import setup_hf_spaces_environment, HFSpacesCompatible | |
| setup_hf_spaces_environment() | |
| except ImportError: | |
| print('Warning: HF Spaces fix not available') | |
| # Load environment variables | |
| load_dotenv() | |
# Set up logging: INFO level on the root logger, plus a module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Point matplotlib, gradio and the Hugging Face caches at writable /tmp paths.
# HF_HOME is used instead of the deprecated TRANSFORMERS_CACHE.
# NOTE(review): libraries that read these variables at import time would need
# to be imported after this point for the values to apply — confirm ordering.
os.environ.update(
    {
        'MPLCONFIGDIR': '/tmp/matplotlib',
        'GRADIO_ALLOW_FLAGGING': 'never',
        'HF_HOME': '/tmp/huggingface',
        'HF_DATASETS_CACHE': '/tmp/huggingface/datasets',
        'HUGGINGFACE_HUB_CACHE': '/tmp/huggingface/hub',
    }
)

# FastAPI app will be created after lifespan is defined

# Create the output and cache directories up front (no error if they exist).
for _required_dir in (
    "outputs",
    "/tmp/matplotlib",
    "/tmp/huggingface",
    "/tmp/huggingface/transformers",
    "/tmp/huggingface/datasets",
    "/tmp/huggingface/hub",
):
    os.makedirs(_required_dir, exist_ok=True)
| # Build public URLs for generated videos (the static "/outputs" mount itself is done after the app is created) | |
def get_video_url(output_path: str) -> str:
    """Convert a local output file path into a URL under the ``/outputs`` route.

    Only the file name of ``output_path`` is used. The host comes from the
    ``SPACE_HOST`` environment variable when it is set, and otherwise falls
    back to the deployment this file was originally written for.

    Args:
        output_path: Path of a generated file on local disk.

    Returns:
        The public URL of the file, or ``output_path`` unchanged when the URL
        cannot be built (best-effort fallback).
    """
    from pathlib import Path
    from urllib.parse import quote

    default_host = "bravedims-ai-avatar-chat.hf.space"
    try:
        filename = Path(output_path).name
        # NOTE(review): SPACE_HOST is expected to be a bare host name such as
        # "user-space.hf.space"; a value that already carries a scheme is
        # accepted as-is. Confirm against the Spaces runtime documentation.
        host = (os.getenv("SPACE_HOST") or default_host).strip().rstrip("/")
        if host.startswith(("http://", "https://")):
            base_url = host
        else:
            base_url = f"https://{host}"
        # Percent-encode the file name so spaces and other reserved characters
        # still produce a valid URL.
        video_url = f"{base_url}/outputs/{quote(filename)}"
        logger.info(f"Generated video URL: {video_url}")
        return video_url
    except Exception as e:
        logger.error(f"Error creating video URL: {e}")
        return output_path  # Fallback to original path
| # Pydantic models for request/response | |
class GenerateRequest(BaseModel):
    """Request payload for avatar generation.

    The audio track comes from ``text_to_speech`` (synthesised) or from
    ``audio_url`` (downloaded); the generation code checks them in that order.
    """

    # Prompt text handed to the video generation step.
    prompt: str

    # --- audio source -----------------------------------------------------
    # Text to convert to speech.
    text_to_speech: Optional[str] = None
    # Direct audio URL.
    audio_url: Optional[HttpUrl] = None
    # Voice profile ID used for speech synthesis.
    voice_id: Optional[str] = "21m00Tcm4TlvDq8ikWAM"

    # --- optional reference image -----------------------------------------
    image_url: Optional[HttpUrl] = None

    # --- generation parameters --------------------------------------------
    guidance_scale: float = 5.0
    audio_scale: float = 3.0
    num_steps: int = 30
    # Passed as --nproc_per_node by the subprocess-based inference path.
    sp_size: int = 1
    # Forwarded as --tea_cache_l1_thresh by the subprocess-based inference path.
    tea_cache_l1_thresh: Optional[float] = None
class GenerateResponse(BaseModel):
    """Response payload describing the outcome of a generation request."""

    # Human-readable status message.
    message: str
    # URL or local path of the generated file.
    output_path: str
    # Wall-clock duration of the request, in seconds.
    processing_time: float
    # True when the audio track was synthesised from text.
    audio_generated: bool = False
    # Label of the method that produced the audio, when known.
    tts_method: Optional[str] = None
| # Try to import TTS clients, but make them optional | |
| try: | |
| from advanced_tts_client import AdvancedTTSClient | |
| ADVANCED_TTS_AVAILABLE = True | |
| logger.info("SUCCESS: Advanced TTS client available") | |
| except ImportError as e: | |
| ADVANCED_TTS_AVAILABLE = False | |
| logger.warning(f"WARNING: Advanced TTS client not available: {e}") | |
| # Always import the robust fallback | |
| try: | |
| from robust_tts_client import RobustTTSClient | |
| ROBUST_TTS_AVAILABLE = True | |
| logger.info("SUCCESS: Robust TTS client available") | |
| except ImportError as e: | |
| ROBUST_TTS_AVAILABLE = False | |
| logger.error(f"ERROR: Robust TTS client not available: {e}") | |
class TTSManager:
    """Manages multiple TTS clients with a fallback chain.

    The advanced client is tried first and the robust client second. Either
    may be absent: a client is only created when its module imported
    successfully and its constructor did not raise.

    Attributes:
        advanced_tts: ``AdvancedTTSClient`` instance, or None.
        robust_tts: ``RobustTTSClient`` instance, or None.
        clients_loaded: True once ``load_models`` has completed.
    """

    def __init__(self):
        """Create whichever TTS clients are available; never raises."""
        self.advanced_tts = None
        self.robust_tts = None
        self.clients_loaded = False

        if ADVANCED_TTS_AVAILABLE:
            try:
                self.advanced_tts = AdvancedTTSClient()
                logger.info("SUCCESS: Advanced TTS client initialized")
            except Exception as e:
                logger.warning(f"WARNING: Advanced TTS client initialization failed: {e}")

        if ROBUST_TTS_AVAILABLE:
            try:
                self.robust_tts = RobustTTSClient()
                logger.info("SUCCESS: Robust TTS client initialized")
            except Exception as e:
                logger.error(f"ERROR: Robust TTS client initialization failed: {e}")

        if not self.advanced_tts and not self.robust_tts:
            logger.error("ERROR: No TTS clients available!")

    async def load_models(self):
        """Load the models of every available client.

        Failures of an individual client are logged and do not abort loading.

        Returns:
            True when the loading pass completed, False on an unexpected error.
        """
        try:
            logger.info("Loading TTS models...")

            # Try to load advanced TTS first
            if self.advanced_tts:
                try:
                    logger.info("[PROCESS] Loading advanced TTS models (this may take a few minutes)...")
                    success = await self.advanced_tts.load_models()
                    if success:
                        logger.info("SUCCESS: Advanced TTS models loaded successfully")
                    else:
                        logger.warning("WARNING: Advanced TTS models failed to load")
                except Exception as e:
                    logger.warning(f"WARNING: Advanced TTS loading error: {e}")

            # Always ensure robust TTS is available
            if self.robust_tts:
                try:
                    await self.robust_tts.load_model()
                    logger.info("SUCCESS: Robust TTS fallback ready")
                except Exception as e:
                    logger.error(f"ERROR: Robust TTS loading failed: {e}")

            self.clients_loaded = True
            return True
        except Exception as e:
            logger.error(f"ERROR: TTS manager initialization failed: {e}")
            return False

    async def text_to_speech(self, text: str, voice_id: Optional[str] = None) -> tuple[str, str]:
        """Convert text to speech, walking the fallback chain.

        Args:
            text: Text to synthesise.
            voice_id: Voice profile identifier passed through to the client.

        Returns:
            ``(audio_file_path, method_used)``.

        Raises:
            HTTPException: 500 when every available client failed.
        """
        if not self.clients_loaded:
            logger.info("TTS models not loaded, loading now...")
            await self.load_models()

        logger.info(f"Generating speech: {text[:50]}...")
        logger.info(f"Voice ID: {voice_id}")

        # Try Advanced TTS first (Facebook VITS / SpeechT5)
        if self.advanced_tts:
            try:
                audio_path = await self.advanced_tts.text_to_speech(text, voice_id)
                return audio_path, "Facebook VITS/SpeechT5"
            except Exception as advanced_error:
                logger.warning(f"Advanced TTS failed: {advanced_error}")

        # Fall back to robust TTS
        if self.robust_tts:
            try:
                logger.info("Falling back to robust TTS...")
                audio_path = await self.robust_tts.text_to_speech(text, voice_id)
                return audio_path, "Robust TTS (Fallback)"
            except Exception as robust_error:
                logger.error(f"Robust TTS also failed: {robust_error}")

        # If we get here, all methods failed
        logger.error("All TTS methods failed!")
        raise HTTPException(
            status_code=500,
            detail="All TTS methods failed. Please check system configuration."
        )

    async def get_available_voices(self):
        """Return the voice catalogue.

        Asks the advanced client when it exposes ``get_available_voices``;
        otherwise, or when that call fails, returns the built-in defaults.
        """
        try:
            if self.advanced_tts and hasattr(self.advanced_tts, 'get_available_voices'):
                return await self.advanced_tts.get_available_voices()
        except Exception as e:
            # `except Exception` (not a bare except) so task cancellation is
            # not swallowed; the failure is logged before falling back.
            logger.warning(f"Could not get voices from advanced TTS, using defaults: {e}")

        # Return default voices if advanced TTS not available
        return {
            "21m00Tcm4TlvDq8ikWAM": "Female (Neutral)",
            "pNInz6obpgDQGcFmaJgB": "Male (Professional)",
            "EXAVITQu4vr4xnSDxMaL": "Female (Sweet)",
            "ErXwobaYiN019PkySvjV": "Male (Professional)",
            "TxGEqnHWrfGW9XjX": "Male (Deep)",
            "yoZ06aMxZJJ28mfd3POQ": "Unisex (Friendly)",
            "AZnzlk1XvdvUeBnXmlld": "Female (Strong)"
        }

    def get_tts_info(self):
        """Return a dict describing the state of the TTS system.

        The base keys are always present; details from the advanced client
        are merged in when it exposes ``get_model_info``.
        """
        info = {
            "clients_loaded": self.clients_loaded,
            "advanced_tts_available": self.advanced_tts is not None,
            "robust_tts_available": self.robust_tts is not None,
            "primary_method": "Robust TTS"
        }

        try:
            if self.advanced_tts and hasattr(self.advanced_tts, 'get_model_info'):
                advanced_info = self.advanced_tts.get_model_info()
                info.update({
                    "advanced_tts_loaded": advanced_info.get("models_loaded", False),
                    "transformers_available": advanced_info.get("transformers_available", False),
                    "primary_method": "Facebook VITS/SpeechT5" if advanced_info.get("models_loaded") else "Robust TTS",
                    "device": advanced_info.get("device", "cpu"),
                    "vits_available": advanced_info.get("vits_available", False),
                    "speecht5_available": advanced_info.get("speecht5_available", False)
                })
        except Exception as e:
            logger.debug(f"Could not get advanced TTS info: {e}")

        return info
| # Import the VIDEO-FOCUSED engine | |
| try: | |
| from omniavatar_video_engine import video_engine | |
| VIDEO_ENGINE_AVAILABLE = True | |
| logger.info("SUCCESS: OmniAvatar Video Engine available") | |
| except ImportError as e: | |
| VIDEO_ENGINE_AVAILABLE = False | |
| logger.error(f"ERROR: OmniAvatar Video Engine not available: {e}") | |
class OmniAvatarAPI:
    """Ties together model discovery, text-to-speech and avatar video generation.

    Attributes:
        model_loaded: True only when every expected OmniAvatar model directory
            exists on disk (set by ``load_model``).
        device: ``"cuda"`` when torch reports CUDA availability, else ``"cpu"``.
        tts_manager: ``TTSManager`` used for all text-to-speech requests.
    """

    def __init__(self):
        self.model_loaded = False
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tts_manager = TTSManager()
        logger.info(f"Using device: {self.device}")
        logger.info("Initialized with robust TTS system")

    def load_model(self):
        """Check which OmniAvatar model directories exist and set ``model_loaded``.

        Returns:
            Always True. Missing models (or an error while checking) only
            leave ``model_loaded`` False so the app keeps running without
            video generation.
        """
        try:
            # Check if models are downloaded (but don't require them)
            model_paths = [
                "./pretrained_models/Wan2.1-T2V-14B",
                "./pretrained_models/OmniAvatar-14B",
                "./pretrained_models/wav2vec2-base-960h"
            ]
            missing_models = [path for path in model_paths if not os.path.exists(path)]

            if missing_models:
                logger.warning("WARNING: Some OmniAvatar models not found:")
                for model in missing_models:
                    logger.warning(f" - {model}")
                logger.info("TIP: App will run in TTS-only mode (no video generation)")
                logger.info("TIP: To enable full avatar generation, download the required models")
                self.model_loaded = False  # Video generation disabled
                return True  # But app can still run

            self.model_loaded = True
            logger.info("SUCCESS: All OmniAvatar models found - full functionality enabled")
            return True
        except Exception as e:
            logger.error(f"Error checking models: {str(e)}")
            logger.info("TIP: Continuing in TTS-only mode")
            self.model_loaded = False
            return True  # Continue running

    @staticmethod
    def _cleanup_temp_files(*paths):
        """Best-effort removal of temporary files; falsy entries are skipped."""
        for path in paths:
            if not path:
                continue
            try:
                os.unlink(path)
            except FileNotFoundError:
                # Already gone: nothing left to clean up.
                pass
            except OSError as e:
                logger.warning(f"Could not remove temporary file {path}: {e}")

    async def download_file(self, url: str, suffix: str = "") -> str:
        """Download ``url`` into a temporary file and return that file's path.

        The caller owns the returned file and is responsible for deleting it.

        Raises:
            HTTPException: 400 for a non-200 response or a network error,
                500 for any other failure.
        """
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(str(url)) as response:
                    if response.status != 200:
                        raise HTTPException(status_code=400, detail=f"Failed to download file from URL: {url}")
                    content = await response.read()

            # Create temporary file (kept on disk for the caller)
            with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
                temp_file.write(content)
                return temp_file.name
        except HTTPException:
            # Keep the deliberate 400 above instead of re-wrapping it as a 500.
            raise
        except aiohttp.ClientError as e:
            logger.error(f"Network error downloading {url}: {e}")
            raise HTTPException(status_code=400, detail=f"Network error downloading file: {e}")
        except Exception as e:
            logger.error(f"Error downloading file from {url}: {e}")
            raise HTTPException(status_code=500, detail=f"Error downloading file: {e}")

    def validate_audio_url(self, url: str) -> bool:
        """Return True when ``url`` looks like an audio file.

        A URL qualifies when its path ends with a common audio extension or
        when the string ``audio`` appears anywhere in it. Unparseable or
        non-string input yields False.
        """
        audio_extensions = ('.mp3', '.wav', '.m4a', '.ogg', '.aac', '.flac')
        try:
            parsed = urlparse(url)
            return parsed.path.lower().endswith(audio_extensions) or 'audio' in url.lower()
        except (ValueError, AttributeError, TypeError):
            return False

    def validate_image_url(self, url: str) -> bool:
        """Return True when the path of ``url`` ends with a common image extension.

        Unparseable or non-string input yields False.
        """
        image_extensions = ('.jpg', '.jpeg', '.png', '.webp', '.bmp', '.gif')
        try:
            return urlparse(url).path.lower().endswith(image_extensions)
        except (ValueError, AttributeError, TypeError):
            return False

    async def generate_avatar(self, request: GenerateRequest) -> tuple[str, float, bool, str]:
        """Generate an avatar video - the primary functionality.

        Audio comes from ``request.text_to_speech`` (synthesised) or
        ``request.audio_url`` (downloaded); an optional reference image is
        downloaded from ``request.image_url``. Temporary audio/image files are
        removed whether or not generation succeeds.

        Returns:
            ``(video_path, processing_time_seconds, audio_generated, method_label)``.

        Raises:
            HTTPException: 503 when the video engine is unavailable or the
                failure message mentions models; 400 when no audio source was
                given; 500 for other failures. HTTP errors raised by the TTS
                and download steps are passed through unchanged.
        """
        import time

        start_time = time.time()
        audio_generated = False
        method_used = "Unknown"

        logger.info("[VIDEO] STARTING AVATAR VIDEO GENERATION")
        logger.info(f"[INFO] Prompt: {request.prompt}")

        if not VIDEO_ENGINE_AVAILABLE:
            # If video engine not available, this is a critical error for a VIDEO app
            raise HTTPException(
                status_code=503,
                detail="Video generation engine not available. This application requires OmniAvatar models for video generation."
            )

        audio_path = None
        image_path = None
        try:
            # PRIORITIZE VIDEO GENERATION
            logger.info("[TARGET] Using OmniAvatar Video Engine for FULL video generation")

            # Handle audio source
            if request.text_to_speech:
                logger.info("[MIC] Generating audio from text...")
                audio_path, method_used = await self.tts_manager.text_to_speech(
                    request.text_to_speech,
                    request.voice_id or "21m00Tcm4TlvDq8ikWAM"
                )
                audio_generated = True
            elif request.audio_url:
                logger.info("📥 Downloading audio from URL...")
                audio_path = await self.download_file(str(request.audio_url), ".mp3")
                method_used = "External Audio"
            else:
                raise HTTPException(status_code=400, detail="Either text_to_speech or audio_url required for video generation")

            # Handle image if provided
            if request.image_url:
                logger.info("[IMAGE] Downloading reference image...")
                parsed = urlparse(str(request.image_url))
                ext = os.path.splitext(parsed.path)[1] or ".jpg"
                image_path = await self.download_file(str(request.image_url), ext)

            # GENERATE VIDEO using OmniAvatar engine
            # NOTE(review): this is a synchronous call made from a coroutine —
            # confirm whether blocking the event loop here is acceptable.
            logger.info("[VIDEO] Generating avatar video with adaptive body animation...")
            video_path, _generation_time = video_engine.generate_avatar_video(
                prompt=request.prompt,
                audio_path=audio_path,
                image_path=image_path,
                guidance_scale=request.guidance_scale,
                audio_scale=request.audio_scale,
                num_steps=request.num_steps
            )

            processing_time = time.time() - start_time
            logger.info(f"SUCCESS: VIDEO GENERATED successfully in {processing_time:.1f}s")
            return video_path, processing_time, audio_generated, f"OmniAvatar Video Generation ({method_used})"
        except HTTPException:
            # Deliberate HTTP errors (400/500/503) keep their status code.
            raise
        except Exception as e:
            logger.error(f"ERROR: Video generation failed: {e}")
            # For a VIDEO generation app, we should NOT fall back to audio-only
            # Instead, provide clear guidance
            if "models" in str(e).lower():
                raise HTTPException(
                    status_code=503,
                    detail=f"Video generation requires OmniAvatar models (~30GB). Please run model download script. Error: {str(e)}"
                ) from e
            raise HTTPException(status_code=500, detail=f"Video generation failed: {str(e)}") from e
        finally:
            # Cleanup temporary files on success and on failure alike
            self._cleanup_temp_files(audio_path, image_path)

    async def generate_avatar_BACKUP(self, request: GenerateRequest) -> tuple[str, float, bool, str]:
        """OLD TTS-ONLY METHOD - kept as backup reference.

        When ``model_loaded`` is False only speech is synthesised and the
        audio file itself is returned as the output. Otherwise the inference
        script is run in a subprocess and the most recently modified video in
        ``./outputs`` is returned.

        Returns:
            ``(output_path, processing_time_seconds, audio_generated, tts_method)``.

        Raises:
            HTTPException: 503 in TTS-only mode without text, 400 when no
                audio source was given, 500 for other failures.
        """
        import sys
        import time

        start_time = time.time()
        audio_generated = False
        tts_method = None
        # Temporary files owned by this call; removed in the finally block.
        audio_path = None
        image_path = None
        temp_input_file = None

        try:
            # Check if video generation is available
            if not self.model_loaded:
                logger.info("🎙️ Running in TTS-only mode (OmniAvatar models not available)")

                # Only generate audio, no video
                if not request.text_to_speech:
                    raise HTTPException(
                        status_code=503,
                        detail="Video generation unavailable. OmniAvatar models not found. Only TTS from text is supported."
                    )

                logger.info(f"Generating speech from text: {request.text_to_speech[:50]}...")
                # Held in its own variable: this file is the deliverable and
                # must not be removed by the cleanup below.
                tts_output_path, tts_method = await self.tts_manager.text_to_speech(
                    request.text_to_speech,
                    request.voice_id or "21m00Tcm4TlvDq8ikWAM"
                )
                processing_time = time.time() - start_time
                logger.info(f"SUCCESS: TTS completed in {processing_time:.1f}s using {tts_method}")
                return tts_output_path, processing_time, True, f"{tts_method} (TTS-only mode)"

            # Original video generation logic (when models are available)
            # Determine audio source
            if request.text_to_speech:
                logger.info(f"Generating speech from text: {request.text_to_speech[:50]}...")
                audio_path, tts_method = await self.tts_manager.text_to_speech(
                    request.text_to_speech,
                    request.voice_id or "21m00Tcm4TlvDq8ikWAM"
                )
                audio_generated = True
            elif request.audio_url:
                logger.info(f"Downloading audio from URL: {request.audio_url}")
                if not self.validate_audio_url(str(request.audio_url)):
                    logger.warning(f"Audio URL may not be valid: {request.audio_url}")
                audio_path = await self.download_file(str(request.audio_url), ".mp3")
                tts_method = "External Audio URL"
            else:
                raise HTTPException(
                    status_code=400,
                    detail="Either text_to_speech or audio_url must be provided"
                )

            # Download image if provided
            if request.image_url:
                logger.info(f"Downloading image from URL: {request.image_url}")
                if not self.validate_image_url(str(request.image_url)):
                    logger.warning(f"Image URL may not be valid: {request.image_url}")
                # Determine image extension from URL or default to .jpg
                parsed = urlparse(str(request.image_url))
                ext = os.path.splitext(parsed.path)[1] or ".jpg"
                image_path = await self.download_file(str(request.image_url), ext)

            # Create temporary input file for inference: prompt@@image@@audio,
            # with an empty image field when no image was supplied.
            with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
                temp_input_file = f.name
                f.write(f"{request.prompt}@@{image_path or ''}@@{audio_path}")

            # Prepare inference command; run it with the current interpreter
            # rather than whichever "python" happens to be first on PATH.
            cmd = [
                sys.executable or "python", "-m", "torch.distributed.run",
                "--standalone", f"--nproc_per_node={request.sp_size}",
                "scripts/inference.py",
                "--config", "configs/inference.yaml",
                "--input_file", temp_input_file,
                "--guidance_scale", str(request.guidance_scale),
                "--audio_scale", str(request.audio_scale),
                "--num_steps", str(request.num_steps)
            ]
            # `is not None` so an explicit threshold of 0.0 is still forwarded.
            if request.tea_cache_l1_thresh is not None:
                cmd.extend(["--tea_cache_l1_thresh", str(request.tea_cache_l1_thresh)])

            logger.info(f"Running inference with command: {' '.join(cmd)}")

            # Run inference
            # NOTE(review): blocking call inside a coroutine — confirm whether
            # blocking the event loop for the whole run is acceptable.
            result = subprocess.run(cmd, capture_output=True, text=True)
            if result.returncode != 0:
                logger.error(f"Inference failed: {result.stderr}")
                raise RuntimeError(f"Inference failed: {result.stderr}")

            # Find output video file: the most recently modified one wins
            output_dir = "./outputs"
            if os.path.exists(output_dir):
                video_files = [
                    os.path.join(output_dir, name)
                    for name in os.listdir(output_dir)
                    if name.endswith(('.mp4', '.avi'))
                ]
                if video_files:
                    output_path = max(video_files, key=os.path.getmtime)
                    processing_time = time.time() - start_time
                    return output_path, processing_time, audio_generated, tts_method

            raise RuntimeError("No output video generated")
        except HTTPException:
            # Deliberate HTTP errors keep their status code.
            raise
        except Exception as e:
            logger.error(f"Generation error: {str(e)}")
            raise HTTPException(status_code=500, detail=str(e)) from e
        finally:
            # Clean up temporary files on success and on failure alike
            self._cleanup_temp_files(temp_input_file, audio_path, image_path)
| # Initialize API | |
| omni_api = OmniAvatarAPI() | |
| # Use FastAPI lifespan instead of deprecated on_event | |
from contextlib import asynccontextmanager


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: load models on startup, log on shutdown.

    Code before ``yield`` runs before the application starts serving; code
    after it runs on shutdown. The function is wrapped with
    ``asynccontextmanager`` so that ``FastAPI(lifespan=...)`` receives an
    async context manager rather than a bare async generator function.

    Args:
        app: The FastAPI application being started (unused here).
    """
    # Startup
    success = omni_api.load_model()
    if not success:
        logger.warning("WARNING: OmniAvatar model loading failed - running in limited mode")

    # Load TTS models; a failure is logged but does not prevent startup.
    try:
        await omni_api.tts_manager.load_models()
        logger.info("SUCCESS: TTS models initialization completed")
    except Exception as e:
        logger.error(f"ERROR: TTS initialization failed: {e}")

    yield

    # Shutdown (if needed)
    logger.info("Application shutting down...")
# Create FastAPI app WITH lifespan parameter so startup/shutdown logic runs.
_app_settings = {
    "title": "OmniAvatar-14B API with Advanced TTS",
    "version": "1.0.0",
    "lifespan": lifespan,
}
app = FastAPI(**_app_settings)

# Add CORS middleware: any origin, method and header is accepted.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive — confirm this is intended for the deployment.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)

# Mount static files for serving generated videos from the "outputs" directory.
_outputs_static = StaticFiles(directory="outputs")
app.mount("/outputs", _outputs_static, name="outputs")
# NOTE(review): no route registration for this handler was visible, so it was
# unreachable over HTTP. The "/health" path is an assumption — confirm it
# against existing clients.
@app.get("/health")
async def health_check():
    """Health check endpoint.

    Returns:
        A dict with the overall status, the compute device, which inputs are
        supported in the current mode, and the TTS details reported by
        ``TTSManager.get_tts_info``.
    """
    tts_info = omni_api.tts_manager.get_tts_info()
    return {
        "status": "healthy",
        "model_loaded": omni_api.model_loaded,
        "video_generation_available": omni_api.model_loaded,
        "tts_only_mode": not omni_api.model_loaded,
        "device": omni_api.device,
        "supports_text_to_speech": True,
        "supports_image_urls": omni_api.model_loaded,
        "supports_audio_urls": omni_api.model_loaded,
        "tts_system": "Advanced TTS with Robust Fallback",
        "advanced_tts_available": ADVANCED_TTS_AVAILABLE,
        "robust_tts_available": ROBUST_TTS_AVAILABLE,
        # Unpacked last: keys from tts_info override the ones set above.
        **tts_info
    }
# NOTE(review): no route registration for this handler was visible, so it was
# unreachable over HTTP. The "/voices" path is an assumption — confirm it
# against existing clients.
@app.get("/voices")
async def get_voices():
    """Get available voice configurations.

    Returns:
        ``{"voices": ...}`` on success, or ``{"error": message}`` when the
        lookup raised.
    """
    try:
        voices = await omni_api.tts_manager.get_available_voices()
        return {"voices": voices}
    except Exception as e:
        logger.error(f"Error getting voices: {e}")
        return {"error": str(e)}
# NOTE(review): no route registration for this handler was visible, so it was
# unreachable over HTTP. The "/generate" path is an assumption — confirm it
# against existing clients.
@app.post("/generate", response_model=GenerateResponse)
async def generate_avatar(request: GenerateRequest):
    """Generate avatar video from prompt, text/audio, and optional image URL.

    Args:
        request: Validated generation parameters.

    Returns:
        A ``GenerateResponse``. ``output_path`` is a public URL when the full
        models are loaded, otherwise the local path returned by generation.

    Raises:
        HTTPException: propagated unchanged from generation; any other error
            is reported as a 500.
    """
    logger.info(f"Generating avatar with prompt: {request.prompt}")
    if request.text_to_speech:
        logger.info(f"Text to speech: {request.text_to_speech[:100]}...")
        logger.info(f"Voice ID: {request.voice_id}")
    if request.audio_url:
        logger.info(f"Audio URL: {request.audio_url}")
    if request.image_url:
        logger.info(f"Image URL: {request.image_url}")

    try:
        output_path, processing_time, audio_generated, tts_method = await omni_api.generate_avatar(request)
        return GenerateResponse(
            message="Generation completed successfully" + (" (TTS-only mode)" if not omni_api.model_loaded else ""),
            output_path=get_video_url(output_path) if omni_api.model_loaded else output_path,
            processing_time=processing_time,
            audio_generated=audio_generated,
            tts_method=tts_method
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        raise HTTPException(status_code=500, detail=f"Unexpected error: {e}")
| # Enhanced Gradio interface | |
def gradio_generate(prompt, text_to_speech, audio_url, image_url, voice_id, guidance_scale, audio_scale, num_steps):
    """Gradio interface wrapper with robust TTS support.

    Builds a ``GenerateRequest`` from the form values, runs the asynchronous
    generation to completion and formats the result for the UI.

    Returns:
        The output path when the full models are loaded, a status text in
        TTS-only mode, or a string starting with ``"Error: "`` on invalid
        input or failure. Nothing is raised to Gradio.
    """
    try:
        # Create request object
        request_data = {
            "prompt": prompt,
            "guidance_scale": guidance_scale,
            "audio_scale": audio_scale,
            "num_steps": int(num_steps)
        }

        # Add audio source: text takes precedence over an audio URL
        if text_to_speech and text_to_speech.strip():
            request_data["text_to_speech"] = text_to_speech
            request_data["voice_id"] = voice_id or "21m00Tcm4TlvDq8ikWAM"
        elif audio_url and audio_url.strip():
            if not omni_api.model_loaded:
                return "Error: Audio URL input requires full OmniAvatar models. Please use text-to-speech instead."
            request_data["audio_url"] = audio_url
        else:
            return "Error: Please provide either text to speech or audio URL"

        if image_url and image_url.strip():
            if not omni_api.model_loaded:
                return "Error: Image URL input requires full OmniAvatar models for video generation."
            request_data["image_url"] = image_url

        request = GenerateRequest(**request_data)

        # Run async function in sync context. asyncio.run creates a fresh
        # event loop and always closes it, including when generation raises.
        output_path, processing_time, audio_generated, tts_method = asyncio.run(
            omni_api.generate_avatar(request)
        )

        success_message = f"SUCCESS: Generation completed in {processing_time:.1f}s using {tts_method}"
        print(success_message)

        if omni_api.model_loaded:
            return output_path
        return f"🎙️ TTS Audio generated successfully using {tts_method}\nFile: {output_path}\n\nWARNING: Video generation unavailable (OmniAvatar models not found)"
    except Exception as e:
        logger.error(f"Gradio generation error: {e}")
        return f"Error: {str(e)}"
# Create Gradio interface. Labels, hints and the output widget depend on
# whether the full OmniAvatar models were found (omni_api.model_loaded).
_tts_only = not omni_api.model_loaded

mode_info = " (TTS-Only Mode)" if _tts_only else ""

if _tts_only:
    description_extra = """
WARNING: Running in TTS-Only Mode - OmniAvatar models not found. Only text-to-speech generation is available.
To enable full video generation, the required model files need to be downloaded.
"""
else:
    description_extra = ""

# Voice profile identifiers offered in the dropdown.
_voice_choices = [
    "21m00Tcm4TlvDq8ikWAM",
    "pNInz6obpgDQGcFmaJgB",
    "EXAVITQu4vr4xnSDxMaL",
    "ErXwobaYiN019PkySvjV",
    "TxGEqnHWrfGW9XjX",
    "yoZ06aMxZJJ28mfd3POQ",
    "AZnzlk1XvdvUeBnXmlld"
]

# Input widgets, in the positional order expected by gradio_generate.
_prompt_input = gr.Textbox(
    label="Prompt",
    placeholder="Describe the character behavior (e.g., 'A friendly person explaining a concept')",
    lines=2
)
_tts_input = gr.Textbox(
    label="Text to Speech",
    placeholder="Enter text to convert to speech",
    lines=3,
    info="Will use best available TTS system (Advanced or Fallback)"
)
_audio_url_input = gr.Textbox(
    label="OR Audio URL",
    placeholder="https://example.com/audio.mp3",
    info="Direct URL to audio file (requires full models)" if _tts_only else "Direct URL to audio file"
)
_image_url_input = gr.Textbox(
    label="Image URL (Optional)",
    placeholder="https://example.com/image.jpg",
    info="Direct URL to reference image (requires full models)" if _tts_only else "Direct URL to reference image"
)
_voice_input = gr.Dropdown(
    choices=_voice_choices,
    value="21m00Tcm4TlvDq8ikWAM",
    label="Voice Profile",
    info="Choose voice characteristics for TTS generation"
)
_guidance_input = gr.Slider(minimum=1, maximum=10, value=5.0, label="Guidance Scale", info="4-6 recommended")
_audio_scale_input = gr.Slider(minimum=1, maximum=10, value=3.0, label="Audio Scale", info="Higher values = better lip-sync")
_steps_input = gr.Slider(minimum=10, maximum=100, value=30, step=1, label="Number of Steps", info="20-50 recommended")

# Only the widget matching the current mode is constructed.
if omni_api.model_loaded:
    _output_widget = gr.Video(label="Generated Avatar Video")
else:
    _output_widget = gr.Textbox(label="TTS Output")

_interface_description = f"""
Generate avatar videos with lip-sync from text prompts and speech using robust TTS system.
{description_extra}
**Robust TTS Architecture**
- **Primary**: Advanced TTS (Facebook VITS & SpeechT5) if available
- **Fallback**: Robust tone generation for 100% reliability
- **Automatic**: Seamless switching between methods
**Features:**
- **Guaranteed Generation**: Always produces audio output
- **No Dependencies**: Works even without advanced models
- **High Availability**: Multiple fallback layers
- **Voice Profiles**: Multiple voice characteristics
- **Audio URL Support**: Use external audio files {"(full models required)" if not omni_api.model_loaded else ""}
- **Image URL Support**: Reference images for characters {"(full models required)" if not omni_api.model_loaded else ""}
**Usage:**
1. Enter a character description in the prompt
2. **Enter text for speech generation** (recommended in current mode)
3. {"Optionally add reference image/audio URLs (requires full models)" if not omni_api.model_loaded else "Optionally add reference image URL and choose audio source"}
4. Choose voice profile and adjust parameters
5. Generate your {"audio" if not omni_api.model_loaded else "avatar video"}!
"""

# Example rows follow the same column order as the inputs above.
_example_rows = [
    [
        "A professional teacher explaining a mathematical concept with clear gestures",
        "Hello students! Today we're going to learn about calculus and derivatives.",
        "",
        "",
        "21m00Tcm4TlvDq8ikWAM",
        5.0,
        3.5,
        30
    ],
    [
        "A friendly presenter speaking confidently to an audience",
        "Welcome everyone to our presentation on artificial intelligence!",
        "",
        "",
        "pNInz6obpgDQGcFmaJgB",
        5.5,
        4.0,
        35
    ]
]

iface = gr.Interface(
    fn=gradio_generate,
    inputs=[
        _prompt_input,
        _tts_input,
        _audio_url_input,
        _image_url_input,
        _voice_input,
        _guidance_input,
        _audio_scale_input,
        _steps_input
    ],
    outputs=_output_widget,
    title="[VIDEO] OmniAvatar-14B - Avatar Video Generation with Adaptive Body Animation",
    description=_interface_description,
    examples=_example_rows,
    allow_flagging="never",
    flagging_dir="/tmp/gradio_flagged"
)
# Mount Gradio app: serve the Gradio UI under /gradio on the same application.
app = gr.mount_gradio_app(app, iface, path="/gradio")

if __name__ == "__main__":
    # uvicorn is only needed when the module is launched directly as a script.
    import uvicorn

    _server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **_server_options)