#!/usr/bin/env python3
"""
Fallback strategies for BackgroundFX Pro.
Implements robust fallback mechanisms when primary processing fails.
"""

import cv2
import numpy as np
import torch
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass
from enum import Enum
import logging
import traceback

# ABSOLUTE IMPORTS for Hugging Face Spaces
from utils.logger import setup_logger
from utils.device import DeviceManager
from utils.config import ConfigManager
from core.quality import QualityAnalyzer

logger = setup_logger(__name__)


def _callable_name(func) -> str:
    """Return a printable name for *func*.

    ``functools.partial`` objects and callable instances have no
    ``__name__``; reading it directly would raise ``AttributeError`` from
    inside the error-handling path.
    """
    return getattr(func, "__name__", repr(func))


class FallbackLevel(Enum):
    """Degradation levels, from no fallback up to returning the input unchanged."""
    NONE = 0
    QUALITY_REDUCTION = 1
    METHOD_SWITCH = 2
    BASIC_PROCESSING = 3
    MINIMAL_PROCESSING = 4
    PASSTHROUGH = 5


@dataclass
class FallbackConfig:
    """Tunables for :class:`FallbackStrategy`."""
    # Maximum number of calls to the wrapped function (first try included).
    max_retries: int = 3
    # Multiplier applied to image size / ``quality`` kwarg on memory errors.
    quality_reduction_factor: float = 0.75
    # Floor for the ``quality`` kwarg when it is reduced.
    min_quality: float = 0.3
    # NOTE(review): enable_caching, cache_size and timeout_seconds are not
    # read anywhere in this module.
    enable_caching: bool = True
    cache_size: int = 10
    timeout_seconds: float = 30.0
    # Allow GPU/CUDA errors to switch processing to the CPU.
    gpu_fallback_to_cpu: bool = True
    # Allow memory errors to shrink the first 3-D ndarray positional argument.
    progressive_downscale: bool = True
    # Smallest (width, height) that progressive downscaling may produce.
    min_resolution: Tuple[int, int] = (320, 240)


class FallbackStrategy:
    """Call a function with retries, degrading its inputs after each failure.

    Each failure is classified by :meth:`_apply_fallback`. The matching
    handler returns adjusted positional/keyword arguments for the next attempt.
    Adjustments accumulate across attempts. For example, repeated memory
    errors keep shrinking the image until ``config.min_resolution`` is reached.
    """

    def __init__(self, config: Optional[FallbackConfig] = None):
        self.config = config or FallbackConfig()
        self.device_manager = DeviceManager()
        self.quality_analyzer = QualityAnalyzer()
        # NOTE(review): not used anywhere in this module.
        self.cache = {}
        # One entry per failed attempt: function name, exception type name, attempt index.
        self.fallback_history = []
        self.current_level = FallbackLevel.NONE

    def execute_with_fallback(self, func, *args, **kwargs) -> Dict[str, Any]:
        """Run ``func(*args, **kwargs)``, retrying with degraded inputs on failure.

        Handlers may inject extra keyword arguments before a retry:
        ``device``, ``batch_size``, ``quality``, ``model_type``, ``use_model``,
        ``method``, ``skip_refinement``, ``fast_mode`` and the simplification
        flags. A ``func`` that rejects them simply fails that attempt.

        Args:
            func: Callable to execute.
            *args: Positional arguments for the first attempt.
            **kwargs: Keyword arguments for the first attempt.

        Returns:
            On success: ``{'success': True, 'result', 'attempts',
            'fallback_level'}``. ``fallback_level`` is the degradation level
            in effect when the result was produced (``NONE`` if the first try
            succeeded). If every attempt fails, or a failure cannot be handled,
            this returns the dict from :meth:`_final_fallback`.
        """
        name = _callable_name(func)
        original_args = args
        current_args = args
        current_kwargs = dict(kwargs)
        last_error: Optional[Exception] = None
        # Every call starts un-degraded; handlers raise the level as needed.
        self.current_level = FallbackLevel.NONE

        for attempt in range(self.config.max_retries):
            logger.info(f"Attempt {attempt + 1}/{self.config.max_retries} for {name}")
            try:
                result = func(*current_args, **current_kwargs)
            except Exception as e:
                last_error = e
                logger.warning(f"Attempt {attempt + 1} failed: {str(e)}")
                # Handlers see the arguments of the attempt that just failed
                # (so degradation compounds). They get a kwargs copy to mutate.
                fallback_result = self._apply_fallback(
                    func, e, attempt, current_args, dict(current_kwargs)
                )
                if not fallback_result['handled']:
                    break
                current_args = fallback_result.get('new_args', current_args)
                current_kwargs = fallback_result.get('new_kwargs', current_kwargs)
                continue

            return {
                'success': True,
                'result': result,
                'attempts': attempt + 1,
                'fallback_level': self.current_level
            }

        logger.error(f"All attempts failed for {name}")
        return self._final_fallback(func, last_error, original_args)

    def _apply_fallback(self, func, error: Exception, attempt: int,
                        original_args: tuple, original_kwargs: dict) -> Dict[str, Any]:
        """Record the failure and dispatch to a handler based on the error.

        ``original_args`` / ``original_kwargs`` are the arguments of the
        attempt that failed; the names are kept for compatibility.

        Returns:
            The handler's dict: ``handled`` plus optional ``new_args`` /
            ``new_kwargs`` for the next attempt.
        """
        self.fallback_history.append({
            'function': _callable_name(func),
            'error': type(error).__name__,
            'attempt': attempt
        })

        message = str(error)
        lowered = message.lower()
        if 'CUDA' in message or 'GPU' in message:
            return self._handle_gpu_error(original_kwargs)
        # Bare MemoryError()/TimeoutError() have empty messages, so check the type too.
        if isinstance(error, MemoryError) or 'memory' in lowered:
            return self._handle_memory_error(original_args, original_kwargs)
        if isinstance(error, TimeoutError) or 'timeout' in lowered:
            return self._handle_timeout_error(original_kwargs)
        if 'model' in lowered:
            return self._handle_model_error(original_kwargs)
        return self._handle_generic_error(attempt, original_kwargs)

    def _reduce_quality(self, quality: float, factor: float) -> float:
        """Scale *quality* by *factor*, floored at ``min_quality`` but never increased."""
        return min(quality, max(self.config.min_quality, quality * factor))

    def _handle_gpu_error(self, kwargs: dict) -> Dict[str, Any]:
        """Switch to CPU and halve ``batch_size`` if GPU fallback is enabled."""
        logger.info("GPU error detected, falling back to CPU")

        if not self.config.gpu_fallback_to_cpu:
            return {'handled': False}

        self.device_manager.device = torch.device('cpu')
        kwargs['device'] = 'cpu'
        if 'batch_size' in kwargs:
            kwargs['batch_size'] = max(1, kwargs['batch_size'] // 2)
        self.current_level = FallbackLevel.METHOD_SWITCH
        return {
            'handled': True,
            'new_kwargs': kwargs
        }

    def _handle_memory_error(self, args: tuple, kwargs: dict) -> Dict[str, Any]:
        """Downscale the first 3-D ndarray argument, else reduce ``quality``.

        The image shrinks by ``quality_reduction_factor`` per call, toward
        ``min_resolution``. A dimension already below the minimum is left
        as-is, never enlarged.
        """
        logger.info("Memory error detected, reducing quality")
        factor = self.config.quality_reduction_factor

        image_idx = next(
            (i for i, arg in enumerate(args)
             if isinstance(arg, np.ndarray) and len(arg.shape) == 3),
            -1,
        )
        if image_idx >= 0 and self.config.progressive_downscale:
            image = args[image_idx]
            h, w = image.shape[:2]
            min_w, min_h = self.config.min_resolution
            new_h = min(h, max(int(h * factor), min_h))
            new_w = min(w, max(int(w * factor), min_w))
            if new_h < h or new_w < w:
                new_args = list(args)
                # cv2.resize takes dsize as (width, height).
                new_args[image_idx] = cv2.resize(image, (new_w, new_h))
                self.current_level = FallbackLevel.QUALITY_REDUCTION
                return {
                    'handled': True,
                    'new_args': tuple(new_args),
                    'new_kwargs': kwargs
                }

        if 'quality' in kwargs:
            kwargs['quality'] = self._reduce_quality(kwargs['quality'], factor)
            self.current_level = FallbackLevel.QUALITY_REDUCTION

        return {
            'handled': True,
            'new_kwargs': kwargs
        }

    def _handle_timeout_error(self, kwargs: dict) -> Dict[str, Any]:
        """Turn off optional refinement/temporal/sampling work that the caller passed."""
        logger.info("Timeout detected, simplifying processing")

        simplifications = {
            'use_refinement': False,
            'use_temporal': False,
            'use_guided_filter': False,
            'iterations': 1,
            'num_samples': 1
        }
        # Only override options the caller already supplied.
        for key, value in simplifications.items():
            if key in kwargs:
                kwargs[key] = value

        self.current_level = FallbackLevel.BASIC_PROCESSING
        return {
            'handled': True,
            'new_kwargs': kwargs
        }

    def _handle_model_error(self, kwargs: dict) -> Dict[str, Any]:
        """Step ``model_type`` down one size, or set ``use_model=False``."""
        logger.info("Model error detected, using simpler model")

        if 'model_type' in kwargs:
            model_hierarchy = ['large', 'base', 'small', 'tiny']
            current = kwargs.get('model_type', 'base')
            if current in model_hierarchy:
                idx = model_hierarchy.index(current)
                if idx < len(model_hierarchy) - 1:
                    kwargs['model_type'] = model_hierarchy[idx + 1]
                    self.current_level = FallbackLevel.METHOD_SWITCH
                    return {
                        'handled': True,
                        'new_kwargs': kwargs
                    }

        # Smallest/unknown model, or no model_type given: disable the model.
        kwargs['use_model'] = False
        self.current_level = FallbackLevel.BASIC_PROCESSING
        return {
            'handled': True,
            'new_kwargs': kwargs
        }

    def _handle_generic_error(self, attempt: int, kwargs: dict) -> Dict[str, Any]:
        """Escalate degradation with the attempt index.

        0: lower quality; 1: ``method='basic'``; 2+: skip refinement / fast mode.
        """
        logger.info(f"Generic error, applying degradation level {attempt + 1}")

        if attempt == 0:
            self.current_level = FallbackLevel.QUALITY_REDUCTION
            if 'quality' in kwargs:
                kwargs['quality'] = self._reduce_quality(kwargs['quality'], 0.8)
        elif attempt == 1:
            self.current_level = FallbackLevel.METHOD_SWITCH
            kwargs['method'] = 'basic'
        else:
            self.current_level = FallbackLevel.MINIMAL_PROCESSING
            kwargs['skip_refinement'] = True
            kwargs['fast_mode'] = True

        return {
            'handled': True,
            'new_kwargs': kwargs
        }

    def _final_fallback(self, func, error: Exception, original_args: tuple) -> Dict[str, Any]:
        """Give up: return the first ndarray positional argument unchanged, or None."""
        logger.error(f"Final fallback for {_callable_name(func)}: {str(error)}")
        self.current_level = FallbackLevel.PASSTHROUGH

        passthrough = next(
            (arg for arg in original_args if isinstance(arg, np.ndarray)), None
        )
        return {
            'success': False,
            'result': passthrough,
            'fallback_level': self.current_level,
            'error': str(error)
        }


class ProcessingFallback:
    """Classical OpenCV segmentation/matting routines that need no learned model."""

    def __init__(self):
        self.logger = setup_logger(f"{__name__}.ProcessingFallback")
        self.quality_analyzer = QualityAnalyzer()
        # Insertion-ordered result cache used by cached_result().
        self._cache: Dict[str, Any] = {}

    def basic_segmentation(self, image: np.ndarray) -> np.ndarray:
        """Segment the foreground with GrabCut seeded by a centred rectangle.

        Accepts gray, BGR or BGRA uint8 images. Returns a uint8 mask
        (255 = foreground). On error, returns :meth:`_center_blob_mask`.
        """
        try:
            # grabCut only accepts 8-bit 3-channel images.
            if image.ndim == 2:
                color = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
            elif image.shape[2] == 4:
                color = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
            else:
                color = image

            h, w = color.shape[:2]
            mask = np.zeros((h, w), np.uint8)
            bgd_model = np.zeros((1, 65), np.float64)
            fgd_model = np.zeros((1, 65), np.float64)

            # Seed: assume the subject lies in the central 80% of the frame.
            rect = (int(w * 0.1), int(h * 0.1), int(w * 0.8), int(h * 0.8))
            cv2.grabCut(color, mask, rect, bgd_model, fgd_model, 5, cv2.GC_INIT_WITH_RECT)

            # 0 = definite background, 2 = probable background.
            return np.where((mask == 2) | (mask == 0), 0, 255).astype('uint8')

        except Exception as e:
            self.logger.error(f"Basic segmentation failed: {e}")
            return self._center_blob_mask(image.shape[:2])

    def _center_blob_mask(self, shape: Tuple[int, int]) -> np.ndarray:
        """Return a binary uint8 mask of a centred ellipse with semi-axes w/3, h/3."""
        h, w = shape
        mask = np.zeros((h, w), dtype=np.uint8)
        center = (w // 2, h // 2)
        axes = (w // 3, h // 3)
        cv2.ellipse(mask, center, axes, 0, 0, 360, 255, -1)
        mask = cv2.GaussianBlur(mask, (21, 21), 10)
        _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
        return mask

    def basic_matting(self, image: np.ndarray, mask: np.ndarray) -> np.ndarray:
        """Turn a mask into a float32 alpha in [0, 1] via morphology and blur.

        A non-uint8 ``mask`` is treated as alpha in [0, 1]. ``image`` is
        currently unused.
        """
        try:
            if mask.dtype != np.uint8:
                # Clip before casting so out-of-range values cannot wrap around.
                mask = np.clip(mask * 255, 0, 255).astype(np.uint8)

            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
            mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)  # fill small holes
            mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)   # drop small specks
            mask = cv2.GaussianBlur(mask, (5, 5), 2)

            return mask.astype(np.float32) / 255.0

        except Exception as e:
            self.logger.error(f"Basic matting failed: {e}")
            return mask.astype(np.float32) / 255.0

    def color_difference_keying(self, image: np.ndarray, key_color: Optional[np.ndarray] = None,
                                threshold: float = 30) -> np.ndarray:
        """Key out pixels close to ``key_color`` (Euclidean distance in colour space).

        When ``key_color`` is None, it is the mean of the four 10x10 corner
        patches. Returns a float32 mask where 1 means "differs from key";
        it returns all-ones on error.
        """
        try:
            if key_color is None:
                h, w = image.shape[:2]
                corners = [
                    image[0:10, 0:10],
                    image[0:10, w-10:w],
                    image[h-10:h, 0:10],
                    image[h-10:h, w-10:w]
                ]
                key_color = np.mean([np.mean(c, axis=(0, 1)) for c in corners], axis=0)

            # Use float to avoid uint8 wraparound when image and key are both uint8.
            pixels = image.astype(np.float64)
            key = np.asarray(key_color, dtype=np.float64)
            diff = np.sqrt(np.sum((pixels - key) ** 2, axis=2))

            mask = (diff > threshold).astype(np.float32)
            mask = cv2.GaussianBlur(mask, (5, 5), 2)
            return mask

        except Exception as e:
            self.logger.error(f"Color keying failed: {e}")
            return np.ones(image.shape[:2], dtype=np.float32)

    def edge_based_segmentation(self, image: np.ndarray) -> np.ndarray:
        """Fill the largest external contour found from Canny edges.

        Returns a uint8 mask (255 inside the contour, all zeros if none is
        found). On error, returns :meth:`_center_blob_mask`.
        """
        try:
            if len(image.shape) == 3:
                gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            else:
                gray = image

            edges = cv2.Canny(gray, 50, 150)
            kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
            # Close gaps between edge fragments so contours become closed shapes.
            closed = cv2.morphologyEx(edges, cv2.MORPH_CLOSE, kernel, iterations=2)

            # OpenCV 3 returns (image, contours, hierarchy), OpenCV 4 returns
            # (contours, hierarchy); [-2] picks contours in both.
            contours = cv2.findContours(closed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]

            mask = np.zeros(gray.shape, dtype=np.uint8)
            if contours:
                largest = max(contours, key=cv2.contourArea)
                cv2.drawContours(mask, [largest], -1, 255, -1)

            return mask

        except Exception as e:
            self.logger.error(f"Edge segmentation failed: {e}")
            return self._center_blob_mask(image.shape[:2])

    def cached_result(self, cache_key: str, fallback_func, *args, **kwargs) -> Any:
        """Memoize ``fallback_func(*args, **kwargs)`` under ``cache_key``.

        When the cache exceeds 100 entries, the 20 oldest insertions are
        evicted. Failures are logged and return None (not cached).
        """
        if not hasattr(self, '_cache'):
            self._cache = {}

        if cache_key in self._cache:
            self.logger.info(f"Using cached result for {cache_key}")
            return self._cache[cache_key]

        try:
            result = fallback_func(*args, **kwargs)
            self._cache[cache_key] = result

            if len(self._cache) > 100:
                # dicts preserve insertion order, so the first keys are the oldest.
                for key in list(self._cache.keys())[:20]:
                    del self._cache[key]

            return result
        except Exception as e:
            self.logger.error(f"Cached computation failed: {e}")
            return None