|
|
""" |
|
|
Main processing pipeline for BackgroundFX Pro. |
|
|
Orchestrates the complete background removal and replacement workflow. |
|
|
""" |
|
|
|
|
|
import cv2 |
|
|
import numpy as np |
|
|
import torch |
|
|
from typing import Dict, List, Optional, Tuple, Union, Callable, Any |
|
|
from dataclasses import dataclass, field |
|
|
from enum import Enum |
|
|
from pathlib import Path |
|
|
import time |
|
|
import threading |
|
|
from queue import Queue |
|
|
import json |
|
|
import hashlib |
|
|
from concurrent.futures import ThreadPoolExecutor, Future |
|
|
|
|
|
from ..utils.logger import setup_logger |
|
|
from ..utils.device import DeviceManager |
|
|
from ..utils.config import ConfigManager |
|
|
from ..utils import TimeEstimator, MemoryMonitor |
|
|
|
|
|
from ..core.models import ModelFactory, ModelType |
|
|
from ..core.temporal import TemporalCoherence |
|
|
from ..core.quality import QualityAnalyzer |
|
|
from ..core.edge import EdgeRefinement |
|
|
from ..core.hair_segmentation import HairSegmentation |
|
|
|
|
|
from ..processing.matting import AlphaMatting, MattingConfig, CompositingEngine |
|
|
from ..processing.fallback import FallbackStrategy, FallbackLevel |
|
|
from ..processing.effects import BackgroundEffects, CompositeEffects, EffectType |
|
|
|
|
|
logger = setup_logger(__name__) |
|
|
|
|
|
|
|
|
class ProcessingMode(Enum):
    """Processing mode types.

    Stored in ``PipelineConfig.mode``.
    NOTE(review): nothing in ``ProcessingPipeline`` as visible here branches on
    the mode — confirm where (or whether) it is consumed.
    """

    PHOTO = "photo"
    VIDEO = "video"
    REALTIME = "realtime"
    BATCH = "batch"
|
|
|
|
|
|
|
|
class PipelineStage(Enum):
    """Pipeline processing stages.

    Members are declared in the order ``ProcessingPipeline.process_image``
    enters them. Progress reported to ``progress_callback`` is derived from a
    stage's position in this declaration order, so reordering or inserting
    members changes the reported progress values.
    """

    INITIALIZATION = "initialization"
    PREPROCESSING = "preprocessing"
    SEGMENTATION = "segmentation"
    MATTING = "matting"
    REFINEMENT = "refinement"
    EFFECTS = "effects"
    COMPOSITING = "compositing"
    POSTPROCESSING = "postprocessing"
    COMPLETE = "complete"
|
|
|
|
|
|
|
|
@dataclass
class PipelineConfig:
    """Configuration for the processing pipeline."""

    # --- Model / device -------------------------------------------------
    # Segmentation model to load. The fallback model loader may overwrite this
    # with ModelType.U2NET_LITE when the requested model fails to load.
    model_type: ModelType = ModelType.RMBG_1_4
    # When False (and no explicit device is given) the device manager is set to 'cpu'.
    use_gpu: bool = True
    # Explicit device string; takes precedence over use_gpu when set.
    device: Optional[str] = None

    # --- Feature toggles -------------------------------------------------
    # NOTE(review): not read by ProcessingPipeline in the code visible here.
    mode: ProcessingMode = ProcessingMode.PHOTO
    # Constructs a TemporalCoherence component at init time.
    enable_temporal: bool = True
    # Constructs a HairSegmentation component whose mask is merged into the
    # segmentation mask.
    enable_hair_refinement: bool = True
    # Runs EdgeRefinement.refine_edges on the alpha matte.
    enable_edge_refinement: bool = True
    # Enables fallback model loading, fallback segmentation and fallback processing.
    enable_fallback: bool = True

    # --- Quality ---------------------------------------------------------
    # Values with special handling: "low" (denoise on input), "high"/"ultra"
    # (detail enhance on input, L-channel equalization on output), "ultra"
    # (additional sharpening on output).
    quality_preset: str = "high"
    # Unpacked as (height, width) when preprocessing.
    target_resolution: Optional[Tuple[int, int]] = None
    # Fit inside target_resolution instead of stretching to it.
    maintain_aspect_ratio: bool = True

    # --- Matting ---------------------------------------------------------
    # Passed as ``method=`` to AlphaMatting.process.
    matting_method: str = "auto"
    matting_config: MattingConfig = field(default_factory=MattingConfig)

    # --- Background / effects -------------------------------------------
    # Blur the supplied background before compositing.
    background_blur: bool = False
    blur_strength: float = 15.0
    # BOKEH / VIGNETTE / FILM_GRAIN are applied to the background,
    # LIGHT_WRAP to the foreground, SHADOW / REFLECTION / GLOW /
    # CHROMATIC_ABERRATION to the composite.
    apply_effects: List[EffectType] = field(default_factory=list)

    # --- Performance -----------------------------------------------------
    # NOTE(review): not read by ProcessingPipeline in the code visible here.
    batch_size: int = 1
    # Size of the thread pool used by process_batch.
    num_workers: int = 4
    # Cache results keyed by input content.
    enable_caching: bool = True
    # Approximate memory budget for the result cache, in MiB.
    cache_size_mb: int = 500

    # --- Output ----------------------------------------------------------
    # NOTE(review): the three output options below are not read by
    # ProcessingPipeline in the code visible here.
    output_format: str = "png"
    output_quality: int = 95
    preserve_metadata: bool = True

    # --- Callbacks -------------------------------------------------------
    # Called with (fraction, message).
    progress_callback: Optional[Callable[[float, str], None]] = None
    # Called with (stage, {'timestamp': ..., 'memory_usage': ...}).
    stage_callback: Optional[Callable[[PipelineStage, Dict], None]] = None
|
|
|
|
|
|
|
|
@dataclass
class PipelineResult:
    """Result from pipeline processing."""

    # True when the main pipeline or the fallback path produced an output.
    success: bool
    # Final composited and postprocessed image (or the fallback output).
    output_image: Optional[np.ndarray] = None
    # Alpha matte at the input image's resolution.
    alpha_matte: Optional[np.ndarray] = None
    # Input image multiplied by the alpha matte.
    foreground: Optional[np.ndarray] = None
    # Background actually used for compositing (resized/effected, or zeros when
    # no background was supplied).
    background: Optional[np.ndarray] = None
    # Free-form diagnostics (sizes, quality metrics, matting confidence,
    # 'fallback_used', ...).
    metadata: Dict[str, Any] = field(default_factory=dict)
    # Wall-clock seconds spent on the call.
    processing_time: float = 0.0
    stages_completed: List[PipelineStage] = field(default_factory=list)
    # Error messages collected during processing.
    errors: List[str] = field(default_factory=list)
    # Mean of edge clarity, alpha spread and overall quality.
    quality_score: float = 0.0
|
|
|
|
|
|
|
|
class ProcessingPipeline: |
|
|
"""Main processing pipeline orchestrator.""" |
|
|
|
|
|
def __init__(self, config: Optional[PipelineConfig] = None): |
|
|
""" |
|
|
Initialize the processing pipeline. |
|
|
|
|
|
Args: |
|
|
config: Pipeline configuration |
|
|
""" |
|
|
self.config = config or PipelineConfig() |
|
|
self.logger = setup_logger(f"{__name__}.ProcessingPipeline") |
|
|
|
|
|
|
|
|
self._initialize_components() |
|
|
|
|
|
|
|
|
self.current_stage = PipelineStage.INITIALIZATION |
|
|
self.processing_stats = {} |
|
|
self.cache = {} |
|
|
self.is_processing = False |
|
|
|
|
|
|
|
|
self.executor = ThreadPoolExecutor(max_workers=self.config.num_workers) |
|
|
|
|
|
self.logger.info("Pipeline initialized successfully") |
|
|
|
|
|
def _initialize_components(self): |
|
|
"""Initialize all pipeline components.""" |
|
|
try: |
|
|
|
|
|
self.device_manager = DeviceManager() |
|
|
if self.config.device: |
|
|
self.device_manager.set_device(self.config.device) |
|
|
elif not self.config.use_gpu: |
|
|
self.device_manager.set_device('cpu') |
|
|
|
|
|
|
|
|
self.model_factory = ModelFactory() |
|
|
self.quality_analyzer = QualityAnalyzer() |
|
|
self.edge_refinement = EdgeRefinement() |
|
|
self.temporal_coherence = TemporalCoherence() if self.config.enable_temporal else None |
|
|
self.hair_segmentation = HairSegmentation() if self.config.enable_hair_refinement else None |
|
|
|
|
|
|
|
|
self.alpha_matting = AlphaMatting(self.config.matting_config) |
|
|
self.compositing_engine = CompositingEngine() |
|
|
self.background_effects = BackgroundEffects() |
|
|
self.composite_effects = CompositeEffects() |
|
|
|
|
|
|
|
|
self.fallback_strategy = FallbackStrategy() if self.config.enable_fallback else None |
|
|
|
|
|
|
|
|
self.memory_monitor = MemoryMonitor() |
|
|
self.time_estimator = TimeEstimator() |
|
|
|
|
|
|
|
|
self._load_model() |
|
|
|
|
|
except Exception as e: |
|
|
self.logger.error(f"Component initialization failed: {e}") |
|
|
raise |
|
|
|
|
|
def _load_model(self): |
|
|
"""Load the segmentation model.""" |
|
|
try: |
|
|
self.logger.info(f"Loading model: {self.config.model_type.value}") |
|
|
|
|
|
self.model = self.model_factory.load_model( |
|
|
self.config.model_type, |
|
|
device=self.device_manager.get_device(), |
|
|
optimize=True |
|
|
) |
|
|
|
|
|
self.logger.info("Model loaded successfully") |
|
|
|
|
|
except Exception as e: |
|
|
self.logger.error(f"Model loading failed: {e}") |
|
|
if self.config.enable_fallback: |
|
|
self.logger.info("Attempting fallback model loading") |
|
|
self.config.model_type = ModelType.U2NET_LITE |
|
|
self.model = self.model_factory.load_model( |
|
|
self.config.model_type, |
|
|
device='cpu' |
|
|
) |
|
|
|
|
|
def process_image(self, |
|
|
image: Union[np.ndarray, str, Path], |
|
|
background: Optional[Union[np.ndarray, str, Path]] = None, |
|
|
**kwargs) -> PipelineResult: |
|
|
""" |
|
|
Process a single image through the pipeline. |
|
|
|
|
|
Args: |
|
|
image: Input image (array or path) |
|
|
background: Optional background image/path |
|
|
**kwargs: Additional processing parameters |
|
|
|
|
|
Returns: |
|
|
PipelineResult with processed image and metadata |
|
|
""" |
|
|
start_time = time.time() |
|
|
self.is_processing = True |
|
|
result = PipelineResult(success=False) |
|
|
|
|
|
try: |
|
|
|
|
|
self._update_stage(PipelineStage.INITIALIZATION) |
|
|
image_array = self._load_image(image) |
|
|
bg_array = self._load_image(background) if background is not None else None |
|
|
|
|
|
|
|
|
cache_key = self._generate_cache_key(image_array, kwargs) |
|
|
|
|
|
|
|
|
if self.config.enable_caching and cache_key in self.cache: |
|
|
self.logger.info("Using cached result") |
|
|
cached_result = self.cache[cache_key] |
|
|
cached_result.processing_time = time.time() - start_time |
|
|
return cached_result |
|
|
|
|
|
|
|
|
self._update_stage(PipelineStage.PREPROCESSING) |
|
|
preprocessed = self._preprocess_image(image_array) |
|
|
result.metadata['original_size'] = image_array.shape[:2] |
|
|
result.metadata['preprocessed_size'] = preprocessed.shape[:2] |
|
|
|
|
|
|
|
|
quality_metrics = self.quality_analyzer.analyze_frame(preprocessed) |
|
|
result.metadata['quality_metrics'] = quality_metrics |
|
|
|
|
|
|
|
|
self._update_stage(PipelineStage.SEGMENTATION) |
|
|
segmentation_mask = self._segment_image(preprocessed) |
|
|
|
|
|
|
|
|
if self.config.enable_hair_refinement: |
|
|
self.logger.info("Applying hair refinement") |
|
|
hair_mask = self.hair_segmentation.segment_hair(preprocessed) |
|
|
segmentation_mask = self._combine_masks(segmentation_mask, hair_mask) |
|
|
|
|
|
|
|
|
self._update_stage(PipelineStage.MATTING) |
|
|
matting_result = self.alpha_matting.process( |
|
|
preprocessed, |
|
|
segmentation_mask, |
|
|
method=self.config.matting_method |
|
|
) |
|
|
alpha_matte = matting_result['alpha'] |
|
|
result.metadata['matting_confidence'] = matting_result['confidence'] |
|
|
|
|
|
|
|
|
self._update_stage(PipelineStage.REFINEMENT) |
|
|
if self.config.enable_edge_refinement: |
|
|
alpha_matte = self.edge_refinement.refine_edges( |
|
|
preprocessed, |
|
|
(alpha_matte * 255).astype(np.uint8) |
|
|
) / 255.0 |
|
|
|
|
|
|
|
|
if preprocessed.shape[:2] != image_array.shape[:2]: |
|
|
alpha_matte = cv2.resize( |
|
|
alpha_matte, |
|
|
(image_array.shape[1], image_array.shape[0]), |
|
|
interpolation=cv2.INTER_LINEAR |
|
|
) |
|
|
|
|
|
|
|
|
foreground = self._extract_foreground(image_array, alpha_matte) |
|
|
|
|
|
|
|
|
self._update_stage(PipelineStage.EFFECTS) |
|
|
|
|
|
if bg_array is not None: |
|
|
|
|
|
bg_array = self._resize_background(bg_array, image_array.shape[:2]) |
|
|
|
|
|
|
|
|
if self.config.background_blur: |
|
|
bg_array = self.background_effects.apply_blur( |
|
|
bg_array, |
|
|
strength=self.config.blur_strength, |
|
|
mask=1 - alpha_matte |
|
|
) |
|
|
|
|
|
|
|
|
if self.config.apply_effects: |
|
|
bg_array = self._apply_effects(bg_array, alpha_matte) |
|
|
else: |
|
|
|
|
|
bg_array = np.zeros_like(image_array) |
|
|
|
|
|
|
|
|
self._update_stage(PipelineStage.COMPOSITING) |
|
|
|
|
|
if self.config.apply_effects and EffectType.LIGHT_WRAP in self.config.apply_effects: |
|
|
foreground = self.background_effects.apply_light_wrap( |
|
|
foreground, bg_array, alpha_matte |
|
|
) |
|
|
|
|
|
composited = self.compositing_engine.composite( |
|
|
foreground, bg_array, alpha_matte |
|
|
) |
|
|
|
|
|
|
|
|
if self.config.apply_effects: |
|
|
composited = self._apply_post_effects(composited, alpha_matte) |
|
|
|
|
|
|
|
|
self._update_stage(PipelineStage.POSTPROCESSING) |
|
|
final_output = self._postprocess_image(composited, alpha_matte) |
|
|
|
|
|
|
|
|
result.quality_score = self._calculate_quality_score( |
|
|
final_output, alpha_matte, quality_metrics |
|
|
) |
|
|
|
|
|
|
|
|
result.success = True |
|
|
result.output_image = final_output |
|
|
result.alpha_matte = alpha_matte |
|
|
result.foreground = foreground |
|
|
result.background = bg_array |
|
|
result.stages_completed = list(PipelineStage) |
|
|
result.processing_time = time.time() - start_time |
|
|
|
|
|
|
|
|
if self.config.enable_caching: |
|
|
self._cache_result(cache_key, result) |
|
|
|
|
|
|
|
|
self._update_stage(PipelineStage.COMPLETE) |
|
|
self.logger.info(f"Processing completed in {result.processing_time:.2f}s") |
|
|
|
|
|
|
|
|
self._update_statistics(result) |
|
|
|
|
|
except Exception as e: |
|
|
self.logger.error(f"Pipeline processing failed: {e}") |
|
|
result.errors.append(str(e)) |
|
|
|
|
|
if self.config.enable_fallback and self.fallback_strategy: |
|
|
self.logger.info("Attempting fallback processing") |
|
|
result = self._fallback_processing(image_array, bg_array) |
|
|
|
|
|
finally: |
|
|
self.is_processing = False |
|
|
|
|
|
return result |
|
|
|
|
|
def _preprocess_image(self, image: np.ndarray) -> np.ndarray: |
|
|
"""Preprocess image for optimal processing.""" |
|
|
processed = image.copy() |
|
|
|
|
|
|
|
|
if self.config.target_resolution: |
|
|
target_h, target_w = self.config.target_resolution |
|
|
h, w = image.shape[:2] |
|
|
|
|
|
if self.config.maintain_aspect_ratio: |
|
|
scale = min(target_w / w, target_h / h) |
|
|
new_w = int(w * scale) |
|
|
new_h = int(h * scale) |
|
|
else: |
|
|
new_w, new_h = target_w, target_h |
|
|
|
|
|
if (new_w, new_h) != (w, h): |
|
|
processed = cv2.resize(processed, (new_w, new_h), |
|
|
interpolation=cv2.INTER_AREA) |
|
|
|
|
|
|
|
|
if self.config.quality_preset == "low": |
|
|
|
|
|
processed = cv2.fastNlMeansDenoising(processed, None, 10, 7, 21) |
|
|
elif self.config.quality_preset in ["high", "ultra"]: |
|
|
|
|
|
processed = cv2.detailEnhance(processed, sigma_s=10, sigma_r=0.15) |
|
|
|
|
|
return processed |
|
|
|
|
|
def _segment_image(self, image: np.ndarray) -> np.ndarray: |
|
|
"""Perform image segmentation.""" |
|
|
try: |
|
|
|
|
|
with torch.no_grad(): |
|
|
|
|
|
input_tensor = self._prepare_input_tensor(image) |
|
|
|
|
|
|
|
|
output = self.model(input_tensor) |
|
|
|
|
|
|
|
|
if isinstance(output, tuple): |
|
|
output = output[0] |
|
|
|
|
|
|
|
|
mask = output.squeeze().cpu().numpy() |
|
|
|
|
|
|
|
|
mask = (mask > 0.5).astype(np.uint8) * 255 |
|
|
|
|
|
|
|
|
if mask.shape[:2] != image.shape[:2]: |
|
|
mask = cv2.resize(mask, (image.shape[1], image.shape[0])) |
|
|
|
|
|
return mask |
|
|
|
|
|
except Exception as e: |
|
|
self.logger.error(f"Segmentation failed: {e}") |
|
|
if self.config.enable_fallback: |
|
|
|
|
|
from ..processing.fallback import ProcessingFallback |
|
|
fallback = ProcessingFallback() |
|
|
return fallback.basic_segmentation(image) |
|
|
raise |
|
|
|
|
|
def _prepare_input_tensor(self, image: np.ndarray) -> torch.Tensor: |
|
|
"""Prepare image tensor for model input.""" |
|
|
|
|
|
model_size = 512 |
|
|
resized = cv2.resize(image, (model_size, model_size)) |
|
|
|
|
|
|
|
|
tensor = torch.from_numpy(resized.transpose(2, 0, 1)).float() |
|
|
tensor = tensor.unsqueeze(0) / 255.0 |
|
|
|
|
|
|
|
|
tensor = tensor.to(self.device_manager.get_device()) |
|
|
|
|
|
return tensor |
|
|
|
|
|
def _combine_masks(self, mask1: np.ndarray, mask2: np.ndarray) -> np.ndarray: |
|
|
"""Combine two masks intelligently.""" |
|
|
|
|
|
m1 = mask1.astype(np.float32) / 255.0 |
|
|
m2 = mask2.astype(np.float32) / 255.0 |
|
|
|
|
|
|
|
|
combined = np.maximum(m1, m2) |
|
|
|
|
|
|
|
|
return (combined * 255).astype(np.uint8) |
|
|
|
|
|
def _extract_foreground(self, image: np.ndarray, |
|
|
alpha: np.ndarray) -> np.ndarray: |
|
|
"""Extract foreground using alpha matte.""" |
|
|
if len(alpha.shape) == 2: |
|
|
alpha = np.expand_dims(alpha, axis=2) |
|
|
|
|
|
if alpha.shape[2] == 1: |
|
|
alpha = np.repeat(alpha, 3, axis=2) |
|
|
|
|
|
|
|
|
foreground = image.astype(np.float32) * alpha |
|
|
|
|
|
return foreground.astype(np.uint8) |
|
|
|
|
|
def _resize_background(self, background: np.ndarray, |
|
|
target_shape: Tuple[int, int]) -> np.ndarray: |
|
|
"""Resize background to match target shape.""" |
|
|
h, w = target_shape |
|
|
bg_h, bg_w = background.shape[:2] |
|
|
|
|
|
if (bg_h, bg_w) == (h, w): |
|
|
return background |
|
|
|
|
|
|
|
|
scale = max(h / bg_h, w / bg_w) |
|
|
new_h = int(bg_h * scale) |
|
|
new_w = int(bg_w * scale) |
|
|
|
|
|
|
|
|
resized = cv2.resize(background, (new_w, new_h), |
|
|
interpolation=cv2.INTER_LINEAR) |
|
|
|
|
|
|
|
|
start_y = (new_h - h) // 2 |
|
|
start_x = (new_w - w) // 2 |
|
|
cropped = resized[start_y:start_y + h, start_x:start_x + w] |
|
|
|
|
|
return cropped |
|
|
|
|
|
def _apply_effects(self, image: np.ndarray, |
|
|
mask: np.ndarray) -> np.ndarray: |
|
|
"""Apply configured effects to image.""" |
|
|
result = image.copy() |
|
|
|
|
|
for effect in self.config.apply_effects: |
|
|
if effect == EffectType.BOKEH: |
|
|
result = self.background_effects.apply_bokeh(result) |
|
|
elif effect == EffectType.VIGNETTE: |
|
|
result = self.background_effects.add_vignette(result) |
|
|
elif effect == EffectType.FILM_GRAIN: |
|
|
result = self.background_effects.add_film_grain(result) |
|
|
|
|
|
return result |
|
|
|
|
|
def _apply_post_effects(self, image: np.ndarray, |
|
|
mask: np.ndarray) -> np.ndarray: |
|
|
"""Apply post-composite effects.""" |
|
|
result = image.copy() |
|
|
|
|
|
for effect in self.config.apply_effects: |
|
|
if effect == EffectType.SHADOW: |
|
|
result = self.background_effects.add_shadow(result, mask) |
|
|
elif effect == EffectType.REFLECTION: |
|
|
result = self.background_effects.add_reflection(result, mask) |
|
|
elif effect == EffectType.GLOW: |
|
|
result = self.background_effects.add_glow(result, mask) |
|
|
elif effect == EffectType.CHROMATIC_ABERRATION: |
|
|
result = self.background_effects.chromatic_aberration(result) |
|
|
|
|
|
return result |
|
|
|
|
|
def _postprocess_image(self, image: np.ndarray, |
|
|
alpha: np.ndarray) -> np.ndarray: |
|
|
"""Apply final postprocessing.""" |
|
|
result = image.copy() |
|
|
|
|
|
|
|
|
if self.config.quality_preset in ["high", "ultra"]: |
|
|
|
|
|
lab = cv2.cvtColor(result, cv2.COLOR_BGR2LAB) |
|
|
l, a, b = cv2.split(lab) |
|
|
l = cv2.equalizeHist(l) |
|
|
result = cv2.cvtColor(cv2.merge([l, a, b]), cv2.COLOR_LAB2BGR) |
|
|
|
|
|
|
|
|
if self.config.quality_preset == "ultra": |
|
|
kernel = np.array([[-1,-1,-1], |
|
|
[-1, 9,-1], |
|
|
[-1,-1,-1]]) |
|
|
result = cv2.filter2D(result, -1, kernel) |
|
|
|
|
|
return result |
|
|
|
|
|
def _calculate_quality_score(self, image: np.ndarray, |
|
|
alpha: np.ndarray, |
|
|
metrics: Dict) -> float: |
|
|
"""Calculate overall quality score.""" |
|
|
scores = [] |
|
|
|
|
|
|
|
|
edge_score = metrics.get('edge_clarity', 0.5) |
|
|
scores.append(edge_score) |
|
|
|
|
|
|
|
|
alpha_std = np.std(alpha) |
|
|
alpha_score = min(alpha_std * 2, 1.0) |
|
|
scores.append(alpha_score) |
|
|
|
|
|
|
|
|
quality_score = metrics.get('overall_quality', 0.5) |
|
|
scores.append(quality_score) |
|
|
|
|
|
return np.mean(scores) |
|
|
|
|
|
def _load_image(self, source: Union[np.ndarray, str, Path]) -> np.ndarray: |
|
|
"""Load image from various sources.""" |
|
|
if isinstance(source, np.ndarray): |
|
|
return source |
|
|
|
|
|
path = Path(source) if not isinstance(source, Path) else source |
|
|
if not path.exists(): |
|
|
raise FileNotFoundError(f"Image not found: {path}") |
|
|
|
|
|
image = cv2.imread(str(path)) |
|
|
if image is None: |
|
|
raise ValueError(f"Failed to load image: {path}") |
|
|
|
|
|
return image |
|
|
|
|
|
def _generate_cache_key(self, image: np.ndarray, |
|
|
params: Dict) -> str: |
|
|
"""Generate cache key for result.""" |
|
|
|
|
|
hasher = hashlib.md5() |
|
|
hasher.update(image.tobytes()) |
|
|
hasher.update(json.dumps(params, sort_keys=True).encode()) |
|
|
return hasher.hexdigest() |
|
|
|
|
|
def _cache_result(self, key: str, result: PipelineResult): |
|
|
"""Cache processing result.""" |
|
|
self.cache[key] = result |
|
|
|
|
|
|
|
|
cache_memory = sum( |
|
|
r.output_image.nbytes if r.output_image is not None else 0 |
|
|
for r in self.cache.values() |
|
|
) |
|
|
|
|
|
max_bytes = self.config.cache_size_mb * 1024 * 1024 |
|
|
|
|
|
if cache_memory > max_bytes: |
|
|
|
|
|
for old_key in list(self.cache.keys())[:len(self.cache)//4]: |
|
|
del self.cache[old_key] |
|
|
|
|
|
def _update_stage(self, stage: PipelineStage): |
|
|
"""Update current processing stage.""" |
|
|
self.current_stage = stage |
|
|
|
|
|
if self.config.stage_callback: |
|
|
self.config.stage_callback(stage, { |
|
|
'timestamp': time.time(), |
|
|
'memory_usage': self.memory_monitor.get_usage() |
|
|
}) |
|
|
|
|
|
if self.config.progress_callback: |
|
|
progress = list(PipelineStage).index(stage) / len(PipelineStage) |
|
|
self.config.progress_callback(progress, stage.value) |
|
|
|
|
|
def _update_statistics(self, result: PipelineResult): |
|
|
"""Update processing statistics.""" |
|
|
if 'total_processed' not in self.processing_stats: |
|
|
self.processing_stats['total_processed'] = 0 |
|
|
self.processing_stats['total_time'] = 0 |
|
|
self.processing_stats['avg_quality'] = 0 |
|
|
|
|
|
self.processing_stats['total_processed'] += 1 |
|
|
self.processing_stats['total_time'] += result.processing_time |
|
|
self.processing_stats['avg_time'] = ( |
|
|
self.processing_stats['total_time'] / |
|
|
self.processing_stats['total_processed'] |
|
|
) |
|
|
|
|
|
|
|
|
n = self.processing_stats['total_processed'] |
|
|
old_avg = self.processing_stats['avg_quality'] |
|
|
self.processing_stats['avg_quality'] = ( |
|
|
(old_avg * (n - 1) + result.quality_score) / n |
|
|
) |
|
|
|
|
|
def _fallback_processing(self, image: np.ndarray, |
|
|
background: Optional[np.ndarray]) -> PipelineResult: |
|
|
"""Fallback processing when main pipeline fails.""" |
|
|
from ..processing.fallback import ProcessingFallback |
|
|
|
|
|
result = PipelineResult(success=False) |
|
|
fallback = ProcessingFallback() |
|
|
|
|
|
try: |
|
|
|
|
|
mask = fallback.basic_segmentation(image) |
|
|
|
|
|
|
|
|
alpha = fallback.basic_matting(image, mask) |
|
|
|
|
|
|
|
|
if background is not None: |
|
|
background = self._resize_background(background, image.shape[:2]) |
|
|
output = self.compositing_engine.composite( |
|
|
image, background, alpha |
|
|
) |
|
|
else: |
|
|
output = image |
|
|
|
|
|
result.success = True |
|
|
result.output_image = output |
|
|
result.alpha_matte = alpha |
|
|
result.metadata['fallback_used'] = True |
|
|
|
|
|
except Exception as e: |
|
|
self.logger.error(f"Fallback processing also failed: {e}") |
|
|
result.errors.append(str(e)) |
|
|
|
|
|
return result |
|
|
|
|
|
def process_batch(self, images: List[Union[np.ndarray, str, Path]], |
|
|
background: Optional[Union[np.ndarray, str, Path]] = None, |
|
|
**kwargs) -> List[PipelineResult]: |
|
|
""" |
|
|
Process multiple images in batch. |
|
|
|
|
|
Args: |
|
|
images: List of input images |
|
|
background: Optional background for all images |
|
|
**kwargs: Additional processing parameters |
|
|
|
|
|
Returns: |
|
|
List of PipelineResults |
|
|
""" |
|
|
results = [] |
|
|
total = len(images) |
|
|
|
|
|
self.logger.info(f"Processing batch of {total} images") |
|
|
|
|
|
|
|
|
futures = [] |
|
|
for i, image in enumerate(images): |
|
|
future = self.executor.submit( |
|
|
self.process_image, image, background, **kwargs |
|
|
) |
|
|
futures.append(future) |
|
|
|
|
|
|
|
|
for i, future in enumerate(futures): |
|
|
try: |
|
|
result = future.result(timeout=30) |
|
|
results.append(result) |
|
|
|
|
|
if self.config.progress_callback: |
|
|
progress = (i + 1) / total |
|
|
self.config.progress_callback( |
|
|
progress, |
|
|
f"Processed {i + 1}/{total}" |
|
|
) |
|
|
|
|
|
except Exception as e: |
|
|
self.logger.error(f"Batch item {i} failed: {e}") |
|
|
results.append(PipelineResult( |
|
|
success=False, |
|
|
errors=[str(e)] |
|
|
)) |
|
|
|
|
|
return results |
|
|
|
|
|
def get_statistics(self) -> Dict[str, Any]: |
|
|
"""Get processing statistics.""" |
|
|
return { |
|
|
**self.processing_stats, |
|
|
'cache_size': len(self.cache), |
|
|
'current_stage': self.current_stage.value, |
|
|
'is_processing': self.is_processing, |
|
|
'device': str(self.device_manager.get_device()), |
|
|
'model_type': self.config.model_type.value |
|
|
} |
|
|
|
|
|
def clear_cache(self): |
|
|
"""Clear the result cache.""" |
|
|
self.cache.clear() |
|
|
self.logger.info("Cache cleared") |
|
|
|
|
|
def shutdown(self): |
|
|
"""Shutdown the pipeline and cleanup resources.""" |
|
|
self.executor.shutdown(wait=True) |
|
|
self.clear_cache() |
|
|
|
|
|
|
|
|
if hasattr(self, 'model'): |
|
|
del self.model |
|
|
torch.cuda.empty_cache() |
|
|
|
|
|
self.logger.info("Pipeline shutdown complete") |