# Provenance: "Create api/pipeline.py" by MogensR (commit ca7243a, 28.5 kB).
"""
Main processing pipeline for BackgroundFX Pro.
Orchestrates the complete background removal and replacement workflow.
"""
import cv2
import numpy as np
import torch
from typing import Dict, List, Optional, Tuple, Union, Callable, Any
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
import time
import threading
from queue import Queue
import json
import hashlib
from concurrent.futures import ThreadPoolExecutor, Future
from ..utils.logger import setup_logger
from ..utils.device import DeviceManager
from ..utils.config import ConfigManager
from ..utils import TimeEstimator, MemoryMonitor
from ..core.models import ModelFactory, ModelType
from ..core.temporal import TemporalCoherence
from ..core.quality import QualityAnalyzer
from ..core.edge import EdgeRefinement
from ..core.hair_segmentation import HairSegmentation
from ..processing.matting import AlphaMatting, MattingConfig, CompositingEngine
from ..processing.fallback import FallbackStrategy, FallbackLevel
from ..processing.effects import BackgroundEffects, CompositeEffects, EffectType
# Module-level logger. ProcessingPipeline creates its own
# "<module>.ProcessingPipeline" logger in __init__ instead of using this one.
logger = setup_logger(__name__)
class ProcessingMode(Enum):
    """Workload kinds a pipeline can be configured for.

    Each member's value is simply its name in lowercase.
    """

    PHOTO = "photo"
    VIDEO = "video"
    REALTIME = "realtime"
    BATCH = "batch"
class PipelineStage(Enum):
    """Ordered stages a single image passes through.

    Declaration order is significant: ProcessingPipeline derives its progress
    fraction from a stage's index in this enum, so COMPLETE must stay last.
    """

    # Setup and input preparation
    INITIALIZATION = "initialization"
    PREPROCESSING = "preprocessing"
    # Mask / matte estimation
    SEGMENTATION = "segmentation"
    MATTING = "matting"
    REFINEMENT = "refinement"
    # Output assembly
    EFFECTS = "effects"
    COMPOSITING = "compositing"
    POSTPROCESSING = "postprocessing"
    # Terminal marker
    COMPLETE = "complete"
@dataclass
class PipelineConfig:
    """Tunable settings consumed by ``ProcessingPipeline``.

    Fields are grouped by concern. Their order is significant because it
    defines the positional order of the generated ``__init__``.
    """

    # ---- Segmentation model ------------------------------------------------
    model_type: ModelType = ModelType.RMBG_1_4
    use_gpu: bool = True
    # Explicit device string; when set it takes precedence over ``use_gpu``.
    device: Optional[str] = None

    # ---- Pipeline behaviour -----------------------------------------------
    mode: ProcessingMode = ProcessingMode.PHOTO
    enable_temporal: bool = True
    enable_hair_refinement: bool = True
    enable_edge_refinement: bool = True
    enable_fallback: bool = True

    # ---- Quality ------------------------------------------------------------
    quality_preset: str = "high"  # one of: low, medium, high, ultra
    # Unpacked as (height, width) by the pipeline's preprocessing step.
    target_resolution: Optional[Tuple[int, int]] = None
    maintain_aspect_ratio: bool = True

    # ---- Alpha matting ------------------------------------------------------
    matting_method: str = "auto"  # one of: auto, trimap, deep, guided
    matting_config: MattingConfig = field(default_factory=MattingConfig)

    # ---- Background effects -------------------------------------------------
    background_blur: bool = False
    blur_strength: float = 15.0
    apply_effects: List[EffectType] = field(default_factory=list)

    # ---- Performance --------------------------------------------------------
    batch_size: int = 1
    num_workers: int = 4  # size of the pipeline's thread pool
    enable_caching: bool = True
    cache_size_mb: int = 500  # memory budget for cached results

    # ---- Output -------------------------------------------------------------
    output_format: str = "png"  # one of: png, jpg, webp
    output_quality: int = 95
    preserve_metadata: bool = True

    # ---- Callbacks ----------------------------------------------------------
    # Called with (progress fraction, label).
    progress_callback: Optional[Callable[[float, str], None]] = None
    # Called with (stage, info dict).
    stage_callback: Optional[Callable[[PipelineStage, Dict], None]] = None
@dataclass
class PipelineResult:
    """Outcome of one pipeline run.

    Only ``success`` is required. Every other field starts empty or zero and
    is filled in as the pipeline progresses.
    """

    success: bool

    # Image outputs (None until produced)
    output_image: Optional[np.ndarray] = None
    alpha_matte: Optional[np.ndarray] = None
    foreground: Optional[np.ndarray] = None
    background: Optional[np.ndarray] = None

    # Bookkeeping
    metadata: Dict[str, Any] = field(default_factory=dict)
    processing_time: float = 0.0  # seconds, measured with time.time()
    stages_completed: List[PipelineStage] = field(default_factory=list)
    errors: List[str] = field(default_factory=list)
    quality_score: float = 0.0
class ProcessingPipeline:
    """Main processing pipeline orchestrator.

    Owns the segmentation model and the matting, refinement, effects and
    compositing components. It also owns a ``ThreadPoolExecutor`` used by
    :meth:`process_batch` and an in-memory result cache. Call
    :meth:`shutdown` to release the pool, the cache and the model.

    :meth:`process_batch` runs :meth:`process_image` on several worker
    threads at once. For that reason the cache, the statistics and the
    in-flight job counter are guarded by ``self._lock``.
    """

    def __init__(self, config: Optional[PipelineConfig] = None):
        """
        Initialize the processing pipeline.

        Args:
            config: Pipeline configuration. A default ``PipelineConfig()`` is
                used when omitted.

        Raises:
            Exception: Whatever component or model initialization raised.
        """
        self.config = config or PipelineConfig()
        self.logger = setup_logger(f"{__name__}.ProcessingPipeline")

        # Shared-state guard (see class docstring) and in-flight job count.
        self._lock = threading.Lock()
        self._active_jobs = 0

        # Initialize components
        self._initialize_components()

        # State management
        self.current_stage = PipelineStage.INITIALIZATION
        self.processing_stats = {}
        self.cache = {}
        self.is_processing = False

        # Thread pool for parallel processing
        self.executor = ThreadPoolExecutor(max_workers=self.config.num_workers)

        self.logger.info("Pipeline initialized successfully")

    def _initialize_components(self):
        """Create every pipeline component and load the segmentation model.

        Optional components (temporal coherence, hair segmentation, fallback
        strategy) are set to ``None`` when they are disabled in the config.

        Raises:
            Exception: Any construction or model-loading error, after logging.
        """
        try:
            # Device management: an explicit device wins; otherwise force the
            # CPU only when GPU use is disabled.
            self.device_manager = DeviceManager()
            if self.config.device:
                self.device_manager.set_device(self.config.device)
            elif not self.config.use_gpu:
                self.device_manager.set_device('cpu')

            # Core components
            self.model_factory = ModelFactory()
            self.quality_analyzer = QualityAnalyzer()
            self.edge_refinement = EdgeRefinement()
            self.temporal_coherence = TemporalCoherence() if self.config.enable_temporal else None
            self.hair_segmentation = HairSegmentation() if self.config.enable_hair_refinement else None

            # Processing components
            self.alpha_matting = AlphaMatting(self.config.matting_config)
            self.compositing_engine = CompositingEngine()
            self.background_effects = BackgroundEffects()
            self.composite_effects = CompositeEffects()

            # Fallback strategy
            self.fallback_strategy = FallbackStrategy() if self.config.enable_fallback else None

            # Memory monitoring
            self.memory_monitor = MemoryMonitor()
            self.time_estimator = TimeEstimator()

            # Load model
            self._load_model()
        except Exception as e:
            self.logger.error(f"Component initialization failed: {e}")
            raise

    def _load_model(self):
        """Load the configured segmentation model into ``self.model``.

        If loading fails and fallback is enabled, ``ModelType.U2NET_LITE`` is
        loaded on the CPU instead. ``config.model_type`` is updated to match,
        and the device manager is switched to the CPU as well.

        Raises:
            Exception: The original error when fallback is disabled, or the
                fallback loader's error if that also fails.
        """
        try:
            self.logger.info(f"Loading model: {self.config.model_type.value}")
            self.model = self.model_factory.load_model(
                self.config.model_type,
                device=self.device_manager.get_device(),
                optimize=True
            )
            self.logger.info("Model loaded successfully")
        except Exception as e:
            self.logger.error(f"Model loading failed: {e}")
            if not self.config.enable_fallback:
                # No fallback means no usable model: surface the failure here
                # rather than leaving ``self.model`` unset for later calls.
                raise
            self.logger.info("Attempting fallback model loading")
            self.config.model_type = ModelType.U2NET_LITE
            self.model = self.model_factory.load_model(
                self.config.model_type,
                device='cpu'
            )
            # _prepare_input_tensor moves inputs to the managed device; keep
            # that on the CPU so inputs match the fallback model.
            self.device_manager.set_device('cpu')

    def _begin_job(self):
        """Record that a process_image call has started."""
        with self._lock:
            self._active_jobs += 1
            self.is_processing = True

    def _end_job(self):
        """Record that a process_image call has finished."""
        with self._lock:
            self._active_jobs -= 1
            self.is_processing = self._active_jobs > 0

    def process_image(self,
                      image: Union[np.ndarray, str, Path],
                      background: Optional[Union[np.ndarray, str, Path]] = None,
                      **kwargs) -> PipelineResult:
        """
        Process a single image through the pipeline.

        Args:
            image: Input image (array or path).
            background: Optional background image (array or path). When
                omitted, the foreground is composited over an all-zero frame.
            **kwargs: Additional parameters. The stages here do not interpret
                them; they only contribute to the cache key.

        Returns:
            PipelineResult with the processed image and metadata.

            On failure the result has ``success=False`` and the error text in
            ``errors``. The exception is when fallback processing succeeds:
            then ``metadata['fallback_used']`` is set and ``errors`` still
            lists the primary failure.
        """
        import dataclasses  # used to copy results into/out of the cache

        start_time = time.time()
        self._begin_job()
        result = PipelineResult(success=False)
        # Bound before ``try`` so the error handler can inspect them even
        # when loading the inputs is what failed.
        image_array = None
        loaded_background = None

        try:
            # Stage 1: Initialization - load inputs
            self._update_stage(PipelineStage.INITIALIZATION)
            image_array = self._load_image(image)
            if background is not None:
                loaded_background = self._load_image(background)
            bg_array = loaded_background

            # Cache lookup. The key covers the background as well: the same
            # image over a different background is a different result.
            cache_key = None
            if self.config.enable_caching:
                cache_key = self._generate_cache_key(
                    image_array, kwargs, loaded_background
                )
                with self._lock:
                    cached_result = self.cache.pop(cache_key, None)
                    if cached_result is not None:
                        # Re-insert so the entry counts as most recently used.
                        self.cache[cache_key] = cached_result
                if cached_result is not None:
                    self.logger.info("Using cached result")
                    # Hand out a copy so the cached entry is never mutated.
                    return dataclasses.replace(
                        cached_result,
                        processing_time=time.time() - start_time,
                    )

            # Stage 2: Preprocessing
            self._update_stage(PipelineStage.PREPROCESSING)
            preprocessed = self._preprocess_image(image_array)
            result.metadata['original_size'] = image_array.shape[:2]
            result.metadata['preprocessed_size'] = preprocessed.shape[:2]

            # Quality analysis
            quality_metrics = self.quality_analyzer.analyze_frame(preprocessed)
            result.metadata['quality_metrics'] = quality_metrics

            # Stage 3: Segmentation
            self._update_stage(PipelineStage.SEGMENTATION)
            segmentation_mask = self._segment_image(preprocessed)

            # Hair refinement. The component is None when it was disabled at
            # construction time, so check it as well as the flag.
            if (self.config.enable_hair_refinement
                    and self.hair_segmentation is not None):
                self.logger.info("Applying hair refinement")
                hair_mask = self.hair_segmentation.segment_hair(preprocessed)
                segmentation_mask = self._combine_masks(segmentation_mask, hair_mask)

            # Stage 4: Matting
            self._update_stage(PipelineStage.MATTING)
            matting_result = self.alpha_matting.process(
                preprocessed,
                segmentation_mask,
                method=self.config.matting_method
            )
            alpha_matte = matting_result['alpha']
            result.metadata['matting_confidence'] = matting_result['confidence']

            # Stage 5: Refinement
            self._update_stage(PipelineStage.REFINEMENT)
            if self.config.enable_edge_refinement:
                # Clip before the uint8 cast so out-of-range alpha cannot wrap.
                alpha_u8 = (np.clip(alpha_matte, 0.0, 1.0) * 255).astype(np.uint8)
                alpha_matte = self.edge_refinement.refine_edges(
                    preprocessed, alpha_u8
                ) / 255.0

            # Bring the matte back to the input resolution if preprocessing
            # resized the image.
            if preprocessed.shape[:2] != image_array.shape[:2]:
                alpha_matte = cv2.resize(
                    alpha_matte,
                    (image_array.shape[1], image_array.shape[0]),
                    interpolation=cv2.INTER_LINEAR
                )

            # Extract foreground
            foreground = self._extract_foreground(image_array, alpha_matte)

            # Stage 6: Background & Effects
            self._update_stage(PipelineStage.EFFECTS)
            if bg_array is not None:
                # Cover-resize and center-crop to the input size.
                bg_array = self._resize_background(bg_array, image_array.shape[:2])

                if self.config.background_blur:
                    bg_array = self.background_effects.apply_blur(
                        bg_array,
                        strength=self.config.blur_strength,
                        mask=1 - alpha_matte
                    )

                if self.config.apply_effects:
                    bg_array = self._apply_effects(bg_array, alpha_matte)
            else:
                # No background supplied: composite over an all-zero (black)
                # frame. The alpha matte is still returned in the result.
                bg_array = np.zeros_like(image_array)

            # Stage 7: Compositing
            self._update_stage(PipelineStage.COMPOSITING)
            if self.config.apply_effects and EffectType.LIGHT_WRAP in self.config.apply_effects:
                foreground = self.background_effects.apply_light_wrap(
                    foreground, bg_array, alpha_matte
                )

            # NOTE(review): ``foreground`` is already multiplied by alpha
            # (_extract_foreground), while _fallback_processing passes the raw
            # image to this same composite() call. If composite() applies
            # alpha itself, edges here are darkened twice. Confirm its contract.
            composited = self.compositing_engine.composite(
                foreground, bg_array, alpha_matte
            )

            # Apply post-composite effects
            if self.config.apply_effects:
                composited = self._apply_post_effects(composited, alpha_matte)

            # Stage 8: Postprocessing
            self._update_stage(PipelineStage.POSTPROCESSING)
            final_output = self._postprocess_image(composited, alpha_matte)

            result.quality_score = self._calculate_quality_score(
                final_output, alpha_matte, quality_metrics
            )

            # Build result
            result.success = True
            result.output_image = final_output
            result.alpha_matte = alpha_matte
            result.foreground = foreground
            result.background = bg_array
            result.stages_completed = list(PipelineStage)
            result.processing_time = time.time() - start_time

            if cache_key is not None:
                # Cache a shallow copy so later edits to ``result`` by the
                # caller do not alter the cached entry.
                self._cache_result(cache_key, dataclasses.replace(result))

            # Complete
            self._update_stage(PipelineStage.COMPLETE)
            self.logger.info(f"Processing completed in {result.processing_time:.2f}s")

            self._update_statistics(result)

        except Exception as e:
            self.logger.error(f"Pipeline processing failed: {e}")
            result.errors.append(str(e))

            # Fallback needs at least the input image to have loaded.
            if (self.config.enable_fallback
                    and self.fallback_strategy is not None
                    and image_array is not None):
                self.logger.info("Attempting fallback processing")
                fallback_result = self._fallback_processing(image_array, loaded_background)
                # Keep the primary failure visible next to any fallback error.
                fallback_result.errors = result.errors + fallback_result.errors
                result = fallback_result
            result.processing_time = time.time() - start_time

        finally:
            self._end_job()

        return result

    def _preprocess_image(self, image: np.ndarray) -> np.ndarray:
        """Resize and apply quality-preset filtering before segmentation.

        Returns a new array; ``image`` itself is not modified.
        """
        processed = image.copy()

        if self.config.target_resolution:
            target_h, target_w = self.config.target_resolution
            h, w = image.shape[:2]

            if self.config.maintain_aspect_ratio:
                # Fit inside the target box.
                scale = min(target_w / w, target_h / h)
                # Never let an extreme aspect ratio collapse a side to 0 px.
                new_w = max(1, int(w * scale))
                new_h = max(1, int(h * scale))
            else:
                new_w, new_h = target_w, target_h

            if (new_w, new_h) != (w, h):
                # INTER_AREA is meant for shrinking; use bilinear when enlarging.
                shrinking = new_w * new_h < w * h
                interpolation = cv2.INTER_AREA if shrinking else cv2.INTER_LINEAR
                processed = cv2.resize(processed, (new_w, new_h),
                                       interpolation=interpolation)

        if self.config.quality_preset == "low":
            # Non-local-means denoising.
            # NOTE(review): NLM denoising is not cheap; confirm it is really
            # intended for the "low" preset.
            processed = cv2.fastNlMeansDenoising(processed, None, 10, 7, 21)
        elif self.config.quality_preset in ("high", "ultra"):
            # Enhance details
            processed = cv2.detailEnhance(processed, sigma_s=10, sigma_r=0.15)

        return processed

    def _segment_image(self, image: np.ndarray) -> np.ndarray:
        """Run the segmentation model and return a uint8 mask.

        Model output is thresholded to 0/255 and then resized back to
        ``image``'s height and width. Bilinear resizing can introduce
        intermediate values.

        If inference fails and fallback is enabled,
        ``ProcessingFallback.basic_segmentation`` is used instead. Otherwise
        the error propagates.
        """
        try:
            with torch.no_grad():
                input_tensor = self._prepare_input_tensor(image)
                output = self.model(input_tensor)

                # Tuple outputs: keep the first element.
                if isinstance(output, tuple):
                    output = output[0]

                mask = output.squeeze().cpu().numpy()
                # Threshold at 0.5. Assumes the model emits probabilities in
                # [0, 1]; logits would need a sigmoid first (TODO confirm).
                mask = (mask > 0.5).astype(np.uint8) * 255

                if mask.shape[:2] != image.shape[:2]:
                    mask = cv2.resize(mask, (image.shape[1], image.shape[0]))

                return mask

        except Exception as e:
            self.logger.error(f"Segmentation failed: {e}")
            if self.config.enable_fallback:
                from ..processing.fallback import ProcessingFallback
                fallback = ProcessingFallback()
                return fallback.basic_segmentation(image)
            raise

    def _prepare_input_tensor(self, image: np.ndarray,
                              model_size: int = 512) -> torch.Tensor:
        """Convert an HxWxC image into a 1xCxSxS float tensor on the device.

        Pixel values are divided by 255, which gives [0, 1] for uint8 input.

        Args:
            image: Image with a trailing channel axis.
            model_size: Square side length fed to the model. Defaults to 512;
                ideally this would come from the model's config.
        """
        resized = cv2.resize(image, (model_size, model_size))

        # NOTE(review): images loaded via cv2 are BGR. Confirm the model
        # expects BGR; otherwise convert with cv2.COLOR_BGR2RGB first.
        tensor = torch.from_numpy(resized.transpose(2, 0, 1)).float()
        tensor = tensor.unsqueeze(0) / 255.0

        return tensor.to(self.device_manager.get_device())

    def _combine_masks(self, mask1: np.ndarray, mask2: np.ndarray) -> np.ndarray:
        """Return the pixel-wise union (maximum) of two 0..255 masks as uint8."""
        m1 = mask1.astype(np.float32) / 255.0
        m2 = mask2.astype(np.float32) / 255.0
        combined = np.maximum(m1, m2)
        return (combined * 255).astype(np.uint8)

    def _extract_foreground(self, image: np.ndarray,
                            alpha: np.ndarray) -> np.ndarray:
        """Return ``image`` multiplied by ``alpha`` (premultiplied), as uint8.

        A 2-D matte is broadcast across all of a colour image's channels (any
        count). A 2-D (grayscale) image is multiplied by a 2-D matte directly.
        ``alpha`` is expected to lie in [0, 1].
        """
        if alpha.ndim == 2 and image.ndim == 3:
            alpha = alpha[..., np.newaxis]

        foreground = image.astype(np.float32) * alpha
        # Clip before the cast so out-of-range alpha cannot wrap around.
        return np.clip(foreground, 0, 255).astype(np.uint8)

    def _resize_background(self, background: np.ndarray,
                           target_shape: Tuple[int, int]) -> np.ndarray:
        """Cover-resize ``background`` and center-crop it to (height, width)."""
        h, w = target_shape
        bg_h, bg_w = background.shape[:2]

        if (bg_h, bg_w) == (h, w):
            return background

        # Scale so the background covers the whole target.
        scale = max(h / bg_h, w / bg_w)
        # Round and clamp: plain int() truncation could leave the resized
        # image one pixel short of the target because of float error.
        new_h = max(h, int(round(bg_h * scale)))
        new_w = max(w, int(round(bg_w * scale)))

        resized = cv2.resize(background, (new_w, new_h),
                             interpolation=cv2.INTER_LINEAR)

        # Center crop
        start_y = (new_h - h) // 2
        start_x = (new_w - w) // 2
        return resized[start_y:start_y + h, start_x:start_x + w]

    def _apply_effects(self, image: np.ndarray,
                       mask: np.ndarray) -> np.ndarray:
        """Apply background effects (bokeh, vignette, film grain) in config order.

        ``mask`` is accepted but not currently used.
        """
        result = image.copy()
        for effect in self.config.apply_effects:
            if effect == EffectType.BOKEH:
                result = self.background_effects.apply_bokeh(result)
            elif effect == EffectType.VIGNETTE:
                result = self.background_effects.add_vignette(result)
            elif effect == EffectType.FILM_GRAIN:
                result = self.background_effects.add_film_grain(result)
        return result

    def _apply_post_effects(self, image: np.ndarray,
                            mask: np.ndarray) -> np.ndarray:
        """Apply post-composite effects (shadow, reflection, glow, aberration)."""
        result = image.copy()
        for effect in self.config.apply_effects:
            if effect == EffectType.SHADOW:
                result = self.background_effects.add_shadow(result, mask)
            elif effect == EffectType.REFLECTION:
                result = self.background_effects.add_reflection(result, mask)
            elif effect == EffectType.GLOW:
                result = self.background_effects.add_glow(result, mask)
            elif effect == EffectType.CHROMATIC_ABERRATION:
                result = self.background_effects.chromatic_aberration(result)
        return result

    def _postprocess_image(self, image: np.ndarray,
                           alpha: np.ndarray) -> np.ndarray:
        """Apply preset-dependent equalisation and sharpening.

        ``alpha`` is accepted but not currently used.
        """
        result = image.copy()

        if self.config.quality_preset in ("high", "ultra"):
            # Equalise only the lightness channel in LAB space.
            # NOTE(review): equalizeHist needs 8-bit input. Confirm that
            # compositing_engine.composite() returns uint8.
            lab = cv2.cvtColor(result, cv2.COLOR_BGR2LAB)
            l_chan, a_chan, b_chan = cv2.split(lab)
            l_chan = cv2.equalizeHist(l_chan)
            result = cv2.cvtColor(cv2.merge([l_chan, a_chan, b_chan]),
                                  cv2.COLOR_LAB2BGR)

            if self.config.quality_preset == "ultra":
                # 3x3 sharpening kernel. Its weights sum to 1, so flat regions
                # are unchanged.
                kernel = np.array([[-1, -1, -1],
                                   [-1, 9, -1],
                                   [-1, -1, -1]])
                result = cv2.filter2D(result, -1, kernel)

        return result

    def _calculate_quality_score(self, image: np.ndarray,
                                 alpha: np.ndarray,
                                 metrics: Dict) -> float:
        """Return the mean of the edge, alpha-contrast and overall-quality scores.

        Missing metric keys default to 0.5.
        """
        edge_score = metrics.get('edge_clarity', 0.5)
        # A larger spread of matte values is treated as cleaner separation.
        alpha_score = min(np.std(alpha) * 2, 1.0)
        overall_score = metrics.get('overall_quality', 0.5)
        return float(np.mean([edge_score, alpha_score, overall_score]))

    def _load_image(self, source: Union[np.ndarray, str, Path]) -> np.ndarray:
        """Return ``source`` unchanged if it is an array, else read it with cv2.

        Raises:
            FileNotFoundError: The path does not exist.
            ValueError: cv2 could not decode the file.
        """
        if isinstance(source, np.ndarray):
            return source

        path = Path(source)
        if not path.exists():
            raise FileNotFoundError(f"Image not found: {path}")

        image = cv2.imread(str(path))
        if image is None:
            raise ValueError(f"Failed to load image: {path}")
        return image

    def _generate_cache_key(self, image: np.ndarray,
                            params: Dict,
                            background: Optional[np.ndarray] = None) -> str:
        """Return an MD5 hex digest identifying an (image, background, params) job.

        MD5 is used only as a cache fingerprint, not for security.
        """
        hasher = hashlib.md5()
        # Shape and dtype are included because identical bytes can describe
        # differently shaped images.
        hasher.update(f"{image.shape}|{image.dtype}".encode())
        hasher.update(image.tobytes())
        if background is not None:
            hasher.update(f"bg|{background.shape}|{background.dtype}".encode())
            hasher.update(background.tobytes())
        # default=repr stops non-JSON values (paths, enums, ...) from raising.
        hasher.update(json.dumps(params, sort_keys=True, default=repr).encode())
        return hasher.hexdigest()

    @staticmethod
    def _result_nbytes(result: PipelineResult) -> int:
        """Return the bytes held by a result's image arrays."""
        arrays = (result.output_image, result.alpha_matte,
                  result.foreground, result.background)
        return sum(arr.nbytes for arr in arrays if arr is not None)

    def _cache_result(self, key: str, result: PipelineResult):
        """Cache ``result`` and evict least-recently-used entries over budget.

        The budget is ``config.cache_size_mb`` and counts every cached image
        array. If a single entry exceeds the budget on its own, it is evicted
        as well.
        """
        max_bytes = self.config.cache_size_mb * 1024 * 1024
        with self._lock:
            # Pop first so a re-inserted key moves to the most-recent end.
            self.cache.pop(key, None)
            self.cache[key] = result

            total = sum(self._result_nbytes(r) for r in self.cache.values())
            # Dicts keep insertion order, so the first key is the oldest.
            while total > max_bytes and self.cache:
                oldest_key = next(iter(self.cache))
                total -= self._result_nbytes(self.cache.pop(oldest_key))

    def _update_stage(self, stage: PipelineStage):
        """Record ``stage`` as current and notify the configured callbacks.

        ``stage_callback`` receives ``(stage, {'timestamp', 'memory_usage'})``.
        ``progress_callback`` receives the stage's position normalised to
        [0, 1] (INITIALIZATION -> 0.0, COMPLETE -> 1.0) and the stage value.
        """
        self.current_stage = stage

        if self.config.stage_callback:
            self.config.stage_callback(stage, {
                'timestamp': time.time(),
                'memory_usage': self.memory_monitor.get_usage()
            })

        if self.config.progress_callback:
            stages = list(PipelineStage)
            # Divide by the last index (not the count) so COMPLETE reports 1.0.
            progress = stages.index(stage) / max(len(stages) - 1, 1)
            self.config.progress_callback(progress, stage.value)

    def _update_statistics(self, result: PipelineResult):
        """Fold a successful result into the running totals and averages."""
        with self._lock:
            stats = self.processing_stats
            if 'total_processed' not in stats:
                stats['total_processed'] = 0
                stats['total_time'] = 0
                stats['avg_quality'] = 0

            stats['total_processed'] += 1
            stats['total_time'] += result.processing_time
            stats['avg_time'] = stats['total_time'] / stats['total_processed']

            # Incremental mean of the quality score.
            n = stats['total_processed']
            stats['avg_quality'] = (
                (stats['avg_quality'] * (n - 1) + result.quality_score) / n
            )

    def _fallback_processing(self, image: np.ndarray,
                             background: Optional[np.ndarray]) -> PipelineResult:
        """Basic segmentation, matting and compositing after a primary failure.

        Without a background, the output is the input image unchanged.
        Errors are recorded in the returned result and never raised.
        """
        from ..processing.fallback import ProcessingFallback

        result = PipelineResult(success=False)
        fallback = ProcessingFallback()

        try:
            mask = fallback.basic_segmentation(image)
            alpha = fallback.basic_matting(image, mask)

            if background is not None:
                background = self._resize_background(background, image.shape[:2])
                output = self.compositing_engine.composite(
                    image, background, alpha
                )
            else:
                output = image

            result.success = True
            result.output_image = output
            result.alpha_matte = alpha
            result.metadata['fallback_used'] = True

        except Exception as e:
            self.logger.error(f"Fallback processing also failed: {e}")
            result.errors.append(str(e))

        return result

    def process_batch(self, images: List[Union[np.ndarray, str, Path]],
                      background: Optional[Union[np.ndarray, str, Path]] = None,
                      *,
                      timeout: Optional[float] = 30.0,
                      **kwargs) -> List[PipelineResult]:
        """
        Process multiple images in parallel on the pipeline's thread pool.

        Args:
            images: List of input images (arrays or paths).
            background: Optional background used for every image.
            timeout: Seconds to wait for each image's result once collection
                reaches it. ``None`` waits indefinitely.
            **kwargs: Forwarded to :meth:`process_image`.

        Returns:
            One PipelineResult per input, in input order. Items that raised
            or timed out get ``success=False`` with the error in ``errors``.
        """
        total = len(images)
        self.logger.info(f"Processing batch of {total} images")

        futures = [
            self.executor.submit(self.process_image, image, background, **kwargs)
            for image in images
        ]

        results = []
        for i, future in enumerate(futures):
            try:
                result = future.result(timeout=timeout)
            except Exception as e:
                # Cancel the job if it has not started yet. A running job
                # cannot be interrupted and will finish in the background.
                future.cancel()
                self.logger.error(f"Batch item {i} failed: {e}")
                result = PipelineResult(success=False, errors=[str(e) or repr(e)])
            # Exactly one result per input, even if the callback below fails.
            results.append(result)

            if self.config.progress_callback:
                try:
                    self.config.progress_callback(
                        (i + 1) / total,
                        f"Processed {i + 1}/{total}"
                    )
                except Exception as e:
                    self.logger.error(f"Progress callback failed: {e}")

        return results

    def get_statistics(self) -> Dict[str, Any]:
        """Return a snapshot of processing statistics and pipeline state."""
        with self._lock:
            stats = dict(self.processing_stats)
            cache_size = len(self.cache)
        return {
            **stats,
            'cache_size': cache_size,
            'current_stage': self.current_stage.value,
            'is_processing': self.is_processing,
            'device': str(self.device_manager.get_device()),
            'model_type': self.config.model_type.value
        }

    def clear_cache(self):
        """Clear the result cache."""
        with self._lock:
            self.cache.clear()
        self.logger.info("Cache cleared")

    def shutdown(self):
        """Wait for pending jobs, then release the pool, cache and model."""
        self.executor.shutdown(wait=True)
        self.clear_cache()

        if hasattr(self, 'model'):
            del self.model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        self.logger.info("Pipeline shutdown complete")