""" | |
Parallel Inference Integration for DittoTalkingHead | |
Integrates parallel processing into the inference pipeline | |
""" | |
import asyncio
import time
from typing import Dict, Any, Tuple, Optional
import numpy as np
import torch
from pathlib import Path
from .parallel_processing import ParallelProcessor, PipelineProcessor
class ParallelInference: | |
""" | |
Parallel inference wrapper for DittoTalkingHead | |
""" | |
def __init__(self, sdk, parallel_processor: Optional[ParallelProcessor] = None): | |
""" | |
Initialize parallel inference | |
Args: | |
sdk: StreamSDK instance | |
parallel_processor: ParallelProcessor instance | |
""" | |
self.sdk = sdk | |
self.parallel_processor = parallel_processor or ParallelProcessor(num_threads=4) | |
# Setup pipeline stages | |
self.pipeline_stages = { | |
'load': self._load_files, | |
'preprocess': self._preprocess, | |
'inference': self._inference, | |
'postprocess': self._postprocess | |
} | |
def _load_files(self, paths: Dict[str, str]) -> Dict[str, Any]: | |
"""Load audio and image files""" | |
audio_path = paths['audio'] | |
image_path = paths['image'] | |
# Parallel loading | |
audio_data, image_data = self.parallel_processor.preprocess_parallel_sync( | |
audio_path, image_path | |
) | |
return { | |
'audio_data': audio_data, | |
'image_data': image_data, | |
'paths': paths | |
} | |
def _preprocess(self, data: Dict[str, Any]) -> Dict[str, Any]: | |
"""Preprocess loaded data""" | |
# Extract audio features | |
audio = data['audio_data']['audio'] | |
sr = data['audio_data']['sample_rate'] | |
# Prepare for SDK | |
import librosa | |
import math | |
# Calculate number of frames | |
num_frames = math.ceil(len(audio) / 16000 * 25) | |
# Prepare image | |
image = data['image_data']['image'] | |
return { | |
'audio': audio, | |
'image': image, | |
'num_frames': num_frames, | |
'paths': data['paths'] | |
} | |
def _inference(self, data: Dict[str, Any]) -> Dict[str, Any]: | |
"""Run inference""" | |
# This would integrate with the actual SDK inference | |
# For now, placeholder | |
return { | |
'result': 'inference_result', | |
'paths': data['paths'] | |
} | |
def _postprocess(self, data: Dict[str, Any]) -> Dict[str, Any]: | |
"""Postprocess results""" | |
return data | |
async def process_parallel_async( | |
self, | |
audio_path: str, | |
image_path: str, | |
output_path: str, | |
**kwargs | |
) -> Tuple[str, float]: | |
""" | |
Process with full parallelization (async) | |
Args: | |
audio_path: Path to audio file | |
image_path: Path to image file | |
output_path: Output video path | |
**kwargs: Additional parameters | |
Returns: | |
Tuple of (output_path, process_time) | |
""" | |
start_time = time.time() | |
# Parallel preprocessing | |
audio_data, image_data = await self.parallel_processor.preprocess_parallel_async( | |
audio_path, image_path, kwargs.get('target_size', 320) | |
) | |
# Run inference (simplified for integration) | |
# In real implementation, this would call SDK methods | |
process_time = time.time() - start_time | |
return output_path, process_time | |
def process_parallel_sync( | |
self, | |
audio_path: str, | |
image_path: str, | |
output_path: str, | |
**kwargs | |
) -> Tuple[str, float]: | |
""" | |
Process with parallelization (sync) | |
Args: | |
audio_path: Path to audio file | |
image_path: Path to image file | |
output_path: Output video path | |
**kwargs: Additional parameters | |
Returns: | |
Tuple of (output_path, process_time) | |
""" | |
start_time = time.time() | |
try: | |
# Parallel preprocessing | |
print("🔄 Starting parallel preprocessing...") | |
preprocess_start = time.time() | |
audio_data, image_data = self.parallel_processor.preprocess_parallel_sync( | |
audio_path, image_path, kwargs.get('target_size', 320) | |
) | |
preprocess_time = time.time() - preprocess_start | |
print(f"✅ Parallel preprocessing completed in {preprocess_time:.2f}s") | |
# Run actual SDK inference | |
# This integrates with the existing SDK | |
from inference import run, seed_everything | |
seed_everything(kwargs.get('seed', 1024)) | |
inference_start = time.time() | |
run(self.sdk, audio_path, image_path, output_path, more_kwargs=kwargs.get('more_kwargs', {})) | |
inference_time = time.time() - inference_start | |
print(f"✅ Inference completed in {inference_time:.2f}s") | |
total_time = time.time() - start_time | |
# Performance breakdown | |
print(f""" | |
🎯 Performance Breakdown: | |
- Preprocessing (parallel): {preprocess_time:.2f}s | |
- Inference: {inference_time:.2f}s | |
- Total: {total_time:.2f}s | |
""") | |
return output_path, total_time | |
except Exception as e: | |
print(f"❌ Error in parallel processing: {e}") | |
raise | |
def get_performance_stats(self) -> Dict[str, Any]: | |
"""Get performance statistics""" | |
return { | |
'num_threads': self.parallel_processor.num_threads, | |
'num_processes': self.parallel_processor.num_processes, | |
'cuda_streams_enabled': self.parallel_processor.use_cuda_streams | |
} | |
class OptimizedInferenceWrapper:
    """
    Wrapper that combines all optimizations.

    Bundles the parallel-preprocessing path with a plain sequential
    fallback behind a single ``process`` call.
    """

    def __init__(
        self,
        sdk,
        use_parallel: bool = True,
        use_cache: bool = True,
        use_gpu_opt: bool = True
    ):
        """
        Initialize optimized inference wrapper.

        Args:
            sdk: StreamSDK instance
            use_parallel: Enable parallel processing
            use_cache: Enable caching
            use_gpu_opt: Enable GPU optimizations
        """
        self.sdk = sdk
        self.use_parallel = use_parallel
        self.use_cache = use_cache
        self.use_gpu_opt = use_gpu_opt
        # Parallel machinery is only constructed when requested.
        self.parallel_processor = None
        self.parallel_inference = None
        if use_parallel:
            self.parallel_processor = ParallelProcessor(num_threads=4)
            self.parallel_inference = ParallelInference(sdk, self.parallel_processor)

    def process(
        self,
        audio_path: str,
        image_path: str,
        output_path: str,
        **kwargs
    ) -> Tuple[str, float, Dict[str, Any]]:
        """
        Process with all optimizations.

        Returns:
            Tuple of (output_path, process_time, stats)
        """
        stats: Dict[str, Any] = {
            'parallel_enabled': self.use_parallel,
            'cache_enabled': self.use_cache,
            'gpu_opt_enabled': self.use_gpu_opt,
        }
        take_parallel_path = self.use_parallel and self.parallel_inference
        if not take_parallel_path:
            # Fallback to sequential
            from inference import run, seed_everything
            started = time.time()
            seed_everything(kwargs.get('seed', 1024))
            run(self.sdk, audio_path, image_path, output_path,
                more_kwargs=kwargs.get('more_kwargs', {}))
            process_time = time.time() - started
            stats['preprocessing'] = 'sequential'
        else:
            output_path, process_time = self.parallel_inference.process_parallel_sync(
                audio_path, image_path, output_path, **kwargs
            )
            stats['preprocessing'] = 'parallel'
        stats['process_time'] = process_time
        return output_path, process_time, stats

    def shutdown(self):
        """Cleanup resources (worker pools held by the processor, if any)."""
        if self.parallel_processor:
            self.parallel_processor.shutdown()