"""
Parallel Inference Integration for DittoTalkingHead
Integrates parallel processing into the inference pipeline
"""
import math
import time
from typing import Dict, Any, Tuple, Optional

from .parallel_processing import ParallelProcessor

class ParallelInference:
"""
Parallel inference wrapper for DittoTalkingHead
"""
def __init__(self, sdk, parallel_processor: Optional[ParallelProcessor] = None):
"""
Initialize parallel inference
Args:
sdk: StreamSDK instance
            parallel_processor: Optional ParallelProcessor; a default 4-thread processor is created if None
"""
self.sdk = sdk
self.parallel_processor = parallel_processor or ParallelProcessor(num_threads=4)
        # Pipeline stages in execution order: load -> preprocess -> inference -> postprocess
self.pipeline_stages = {
'load': self._load_files,
'preprocess': self._preprocess,
'inference': self._inference,
'postprocess': self._postprocess
}
def _load_files(self, paths: Dict[str, str]) -> Dict[str, Any]:
"""Load audio and image files"""
audio_path = paths['audio']
image_path = paths['image']
# Parallel loading
audio_data, image_data = self.parallel_processor.preprocess_parallel_sync(
audio_path, image_path
)
return {
'audio_data': audio_data,
'image_data': image_data,
'paths': paths
}
def _preprocess(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""Preprocess loaded data"""
# Extract audio features
audio = data['audio_data']['audio']
sr = data['audio_data']['sample_rate']
        # Number of output frames at 25 fps for the audio's duration
        num_frames = math.ceil(len(audio) / sr * 25)
# Prepare image
image = data['image_data']['image']
return {
'audio': audio,
'image': image,
'num_frames': num_frames,
'paths': data['paths']
}
def _inference(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""Run inference"""
# This would integrate with the actual SDK inference
# For now, placeholder
return {
'result': 'inference_result',
'paths': data['paths']
}
def _postprocess(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""Postprocess results"""
return data
async def process_parallel_async(
self,
audio_path: str,
image_path: str,
output_path: str,
**kwargs
) -> Tuple[str, float]:
"""
        Preprocess audio and image in parallel (async); SDK inference is not yet integrated into this path
Args:
audio_path: Path to audio file
image_path: Path to image file
output_path: Output video path
**kwargs: Additional parameters
Returns:
Tuple of (output_path, process_time)
"""
start_time = time.time()
# Parallel preprocessing
audio_data, image_data = await self.parallel_processor.preprocess_parallel_async(
audio_path, image_path, kwargs.get('target_size', 320)
)
        # Inference is not yet integrated here; only the parallel
        # preprocessing above runs in the async path.
process_time = time.time() - start_time
return output_path, process_time
def process_parallel_sync(
self,
audio_path: str,
image_path: str,
output_path: str,
**kwargs
) -> Tuple[str, float]:
"""
Process with parallelization (sync)
Args:
audio_path: Path to audio file
image_path: Path to image file
output_path: Output video path
**kwargs: Additional parameters
Returns:
Tuple of (output_path, process_time)
"""
start_time = time.time()
try:
# Parallel preprocessing
print("🔄 Starting parallel preprocessing...")
preprocess_start = time.time()
audio_data, image_data = self.parallel_processor.preprocess_parallel_sync(
audio_path, image_path, kwargs.get('target_size', 320)
)
preprocess_time = time.time() - preprocess_start
print(f"✅ Parallel preprocessing completed in {preprocess_time:.2f}s")
            # Run the actual SDK inference via the existing run() entry point
            from inference import run, seed_everything
seed_everything(kwargs.get('seed', 1024))
inference_start = time.time()
run(self.sdk, audio_path, image_path, output_path, more_kwargs=kwargs.get('more_kwargs', {}))
inference_time = time.time() - inference_start
print(f"✅ Inference completed in {inference_time:.2f}s")
total_time = time.time() - start_time
# Performance breakdown
print(f"""
🎯 Performance Breakdown:
- Preprocessing (parallel): {preprocess_time:.2f}s
- Inference: {inference_time:.2f}s
- Total: {total_time:.2f}s
""")
return output_path, total_time
except Exception as e:
print(f"❌ Error in parallel processing: {e}")
raise
def get_performance_stats(self) -> Dict[str, Any]:
"""Get performance statistics"""
return {
'num_threads': self.parallel_processor.num_threads,
'num_processes': self.parallel_processor.num_processes,
'cuda_streams_enabled': self.parallel_processor.use_cuda_streams
}
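
# Usage sketch for ParallelInference on its own (illustrative): `sdk` must be
# a configured StreamSDK instance, and the file paths below are hypothetical
# placeholders, not part of this module:
#
#     inference = ParallelInference(sdk)
#     output, elapsed = inference.process_parallel_sync(
#         "input.wav", "portrait.png", "result.mp4", target_size=320, seed=1024
#     )
#     print(inference.get_performance_stats())
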
class OptimizedInferenceWrapper:
"""
Wrapper that combines all optimizations
"""
def __init__(
self,
sdk,
use_parallel: bool = True,
use_cache: bool = True,
use_gpu_opt: bool = True
):
"""
Initialize optimized inference wrapper
Args:
sdk: StreamSDK instance
use_parallel: Enable parallel processing
use_cache: Enable caching
use_gpu_opt: Enable GPU optimizations
"""
self.sdk = sdk
self.use_parallel = use_parallel
self.use_cache = use_cache
self.use_gpu_opt = use_gpu_opt
# Initialize components
if use_parallel:
self.parallel_processor = ParallelProcessor(num_threads=4)
self.parallel_inference = ParallelInference(sdk, self.parallel_processor)
else:
self.parallel_processor = None
self.parallel_inference = None
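        # Note: use_cache and use_gpu_opt are currently only recorded for the
        # stats dict returned by process(); no components are constructed here.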
def process(
self,
audio_path: str,
image_path: str,
output_path: str,
**kwargs
) -> Tuple[str, float, Dict[str, Any]]:
"""
Process with all optimizations
Returns:
Tuple of (output_path, process_time, stats)
"""
stats = {
'parallel_enabled': self.use_parallel,
'cache_enabled': self.use_cache,
'gpu_opt_enabled': self.use_gpu_opt
}
if self.use_parallel and self.parallel_inference:
output_path, process_time = self.parallel_inference.process_parallel_sync(
audio_path, image_path, output_path, **kwargs
)
stats['preprocessing'] = 'parallel'
else:
# Fallback to sequential
from inference import run, seed_everything
start_time = time.time()
seed_everything(kwargs.get('seed', 1024))
run(self.sdk, audio_path, image_path, output_path, more_kwargs=kwargs.get('more_kwargs', {}))
process_time = time.time() - start_time
stats['preprocessing'] = 'sequential'
stats['process_time'] = process_time
return output_path, process_time, stats
def shutdown(self):
"""Cleanup resources"""
if self.parallel_processor:
self.parallel_processor.shutdown()
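

if __name__ == "__main__":
    # Minimal end-to-end sketch (illustrative only): `sdk` is a placeholder
    # for a configured StreamSDK instance, and every file path below is
    # hypothetical.
    sdk = ...  # construct/load the project's StreamSDK here
    wrapper = OptimizedInferenceWrapper(sdk, use_parallel=True, use_cache=True)
    try:
        out_path, elapsed, stats = wrapper.process(
            "input.wav",     # hypothetical audio path
            "portrait.png",  # hypothetical source image path
            "result.mp4",    # hypothetical output video path
            seed=1024,
        )
        print(f"Saved {out_path} in {elapsed:.2f}s; stats: {stats}")
    finally:
        # Release worker pools even if processing failed
        wrapper.shutdown()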