|
|
""" |
|
|
Professional Hair Segmentation Module |
|
|
===================================== |
|
|
|
|
|
This module provides high-quality hair segmentation for video processing |
|
|
using SAM2 + MatAnyone pipeline with comprehensive error handling and fallbacks. |
|
|
|
|
|
Author: BackgroundFX Pro |
|
|
License: MIT |
|
|
""" |
|
|
|
|
|
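# Quick usage sketch (all names below are defined in this module; see the
# __main__ block at the bottom for a runnable demo):
#
#   pipeline = create_pipeline({'device': 'auto'})
#   result = pipeline.segment_frame(frame_bgr)   # frame_bgr: np.ndarray (H, W, 3)
#   hair_mask = result.mask                      # float32 mask, values in [0, 1]
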
import os
import torch
import cv2
import numpy as np
import logging
from typing import Dict, List, Tuple, Optional, Union
from pathlib import Path
import warnings
from dataclasses import dataclass
from abc import ABC, abstractmethod

# Keep the OpenMP/MKL thread pools small so video processing does not oversubscribe CPUs
os.environ['OMP_NUM_THREADS'] = '4'
os.environ['MKL_NUM_THREADS'] = '4'

# Module-wide logging configuration
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


@dataclass
class SegmentationResult:
    """Result container for hair segmentation"""
    mask: np.ndarray
    confidence: float
    coverage_percent: float
    asymmetry_score: float
    processing_time: float
    fallback_used: bool
    quality_score: float
    error_message: Optional[str] = None


class BaseSegmentationModel(ABC):
    """Abstract base class for segmentation models"""

    @abstractmethod
    def initialize(self) -> bool:
        """Initialize the model"""
        pass

    @abstractmethod
    def segment(self, frame: np.ndarray) -> np.ndarray:
        """Segment hair in frame"""
        pass

    @abstractmethod
    def get_model_name(self) -> str:
        """Get model name for logging"""
        pass


class SAM2Model(BaseSegmentationModel):
    """SAM2 segmentation model wrapper"""

    def __init__(self, model_path: Optional[str] = None, device: str = 'auto'):
        self.model_path = model_path
        self.device = self._get_best_device(device)
        self.predictor = None
        self.initialized = False

    def _get_best_device(self, device: str) -> str:
        """Determine best available device"""
        if device == 'auto':
            return 'cuda' if torch.cuda.is_available() else 'cpu'
        return device

    def initialize(self) -> bool:
        """Initialize SAM2 model"""
        try:
            logger.info("🤖 Initializing SAM2 model...")

            try:
                # segment() prompts SAM2 one frame at a time via set_image()/predict(),
                # which is the image-predictor API, so build that rather than the
                # video predictor.
                from sam2.build_sam import build_sam2
                from sam2.sam2_image_predictor import SAM2ImagePredictor
            except ImportError:
                logger.error("SAM2 not found. Please install SAM2.")
                return False

            # Use the provided checkpoint if it exists, otherwise fall back to the
            # stock hiera-large checkpoint. The config name below is the standard
            # one shipped with SAM2 for that checkpoint; adjust it if a different
            # model size is loaded.
            if self.model_path and Path(self.model_path).exists():
                checkpoint = self.model_path
            else:
                checkpoint = "sam2_hiera_large.pt"
            sam2_model = build_sam2("sam2_hiera_l.yaml", checkpoint, device=self.device)
            self.predictor = SAM2ImagePredictor(sam2_model)

            self.initialized = True
            logger.info(f"✅ SAM2 initialized on {self.device}")
            return True

        except Exception as e:
            logger.error(f"❌ SAM2 initialization failed: {e}")
            return False

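    # Note on prompting: segment() seeds SAM2 with a single positive point at the
    # frame centre. That is a coarse heuristic (it assumes the subject is roughly
    # centred); SAM2 returns several candidate masks with scores and the highest-
    # scoring one is kept. A face or head detector would give a tighter prompt if
    # more precise hair localisation is needed.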
    def segment(self, frame: np.ndarray) -> np.ndarray:
        """Segment using SAM2"""
        if not self.initialized:
            raise RuntimeError("SAM2 model not initialized")

        try:
            # SAM2 expects RGB input; OpenCV frames are BGR
            if len(frame.shape) == 3:
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            else:
                frame_rgb = frame

            self.predictor.set_image(frame_rgb)

            # Single positive point prompt at the frame centre
            height, width = frame_rgb.shape[:2]
            center_point = np.array([[width // 2, height // 2]])

            masks, scores, _ = self.predictor.predict(
                point_coords=center_point,
                point_labels=np.array([1])
            )

            # Keep the highest-scoring candidate mask
            if len(masks) > 0:
                best_mask_idx = np.argmax(scores)
                return masks[best_mask_idx].astype(np.float32)
            else:
                return np.zeros((height, width), dtype=np.float32)

        except Exception as e:
            logger.error(f"SAM2 segmentation failed: {e}")
            raise

    def get_model_name(self) -> str:
        return "SAM2"


class MatAnyoneModel(BaseSegmentationModel):
    """MatAnyone model wrapper with quality checking"""

    def __init__(self, use_hf_api: bool = True, hf_token: Optional[str] = None):
        self.use_hf_api = use_hf_api
        self.hf_token = hf_token
        self.client = None
        self.processor = None
        self.initialized = False
        self.quality_threshold = 0.3

    def initialize(self) -> bool:
        """Initialize MatAnyone model"""
        try:
            logger.info("🎭 Initializing MatAnyone model...")

            if self.use_hf_api:
                from gradio_client import Client
                self.client = Client("PeiqingYang/MatAnyone", hf_token=self.hf_token)
                logger.info("✅ MatAnyone HF API initialized")
            else:
                logger.warning("Local MatAnyone not implemented yet")
                return False

            self.initialized = True
            return True

        except Exception as e:
            logger.error(f"❌ MatAnyone initialization failed: {e}")
            return False

    def segment(self, frame: np.ndarray) -> np.ndarray:
        """MatAnyone is primarily for matting, not segmentation"""
        raise NotImplementedError("MatAnyone is used for matting, not direct segmentation")

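    # matte() refines a rough segmentation into a soft alpha matte: the image and a
    # trimap (definite foreground / definite background / unknown band) are written
    # to temporary PNG files and sent to the MatAnyone HuggingFace Space. The Space
    # name and its "/predict" endpoint are taken as-is from this module; the exact
    # return format depends on that Space.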
    def matte(self, image: np.ndarray, trimap: np.ndarray) -> np.ndarray:
        """Apply matting using MatAnyone"""
        if not self.initialized:
            raise RuntimeError("MatAnyone model not initialized")

        import tempfile

        # Write inputs to temporary files for the API call
        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as img_file:
            cv2.imwrite(img_file.name, image)
            img_path = img_file.name

        with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tri_file:
            cv2.imwrite(tri_file.name, trimap)
            tri_path = tri_file.name

        try:
            if self.use_hf_api:
                result = self._process_hf_api(img_path, tri_path)
            else:
                result = self._process_local(img_path, tri_path)
            return result

        except Exception as e:
            logger.error(f"MatAnyone matting failed: {e}")
            raise
        finally:
            # Remove the temporary files even if processing failed
            os.unlink(img_path)
            os.unlink(tri_path)

    def _process_hf_api(self, image_path: str, trimap_path: str) -> np.ndarray:
        """Process using HuggingFace API"""
        try:
            result = self.client.predict(
                image=image_path,
                trimap=trimap_path,
                api_name="/predict"
            )

            # The client may return a file path rather than an array
            if isinstance(result, str):
                result_image = cv2.imread(result)
                return result_image
            else:
                return result

        except Exception as e:
            logger.error(f"HF API processing failed: {e}")
            raise

    def _process_local(self, image_path: str, trimap_path: str) -> np.ndarray:
        """Process locally - placeholder for implementation"""
        raise NotImplementedError("Local MatAnyone processing not implemented")

    def get_model_name(self) -> str:
        return "MatAnyone"


class TraditionalCVModel(BaseSegmentationModel):
    """Traditional computer vision fallback"""

    def __init__(self):
        self.initialized = False

    def initialize(self) -> bool:
        """Initialize traditional CV methods"""
        self.initialized = True
        return True

    def segment(self, frame: np.ndarray) -> np.ndarray:
        """Traditional hair segmentation using color and texture"""
        try:
            # Work in HSV and LAB, which separate hair tones better than raw BGR
            hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
            lab = cv2.cvtColor(frame, cv2.COLOR_BGR2LAB)

            hair_mask_hsv = self._detect_hair_hsv(hsv)
            hair_mask_lab = self._detect_hair_lab(lab)

            # Union of the two detections
            combined_mask = cv2.bitwise_or(hair_mask_hsv, hair_mask_lab)

            # Morphological cleanup: close small holes, then remove speckle
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7))
            combined_mask = cv2.morphologyEx(combined_mask, cv2.MORPH_CLOSE, kernel)
            combined_mask = cv2.morphologyEx(combined_mask, cv2.MORPH_OPEN, kernel)

            return combined_mask.astype(np.float32) / 255.0

        except Exception as e:
            logger.error(f"Traditional CV segmentation failed: {e}")
            raise

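    # The HSV thresholds below are rough heuristics for common hair colours
    # (dark, brown/auburn, blonde). They will also pick up other dark or
    # warm-toned regions, which is why this model is only used as a last-resort
    # fallback when the learned models are unavailable.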
    def _detect_hair_hsv(self, hsv: np.ndarray) -> np.ndarray:
        """Detect hair in HSV color space"""
        ranges = [
            # Dark / black hair (any hue, low value)
            ([0, 0, 0], [180, 255, 80]),
            # Brown / auburn hair (warm hues, moderate value)
            ([8, 50, 20], [25, 255, 200]),
            # Blonde hair (warm hues, high value)
            ([15, 30, 100], [35, 255, 255])
        ]

        masks = []
        for lower, upper in ranges:
            mask = cv2.inRange(hsv, np.array(lower), np.array(upper))
            masks.append(mask)

        # Union of all colour ranges
        final_mask = masks[0]
        for mask in masks[1:]:
            final_mask = cv2.bitwise_or(final_mask, mask)

        return final_mask

    def _detect_hair_lab(self, lab: np.ndarray) -> np.ndarray:
        """Detect hair in LAB color space"""
        # Dark pixels in the lightness channel
        l_channel = lab[:, :, 0]
        hair_mask = cv2.inRange(l_channel, 0, 120)
        return hair_mask

    def get_model_name(self) -> str:
        return "TraditionalCV"


class TemporalSmoother:
    """Temporal smoothing for video sequences"""

    def __init__(self, smoothing_factor: float = 0.7, change_threshold: float = 0.05):
        self.smoothing_factor = smoothing_factor
        self.change_threshold = change_threshold
        self.previous_mask = None
        self.correction_count = 0
        self.total_frames = 0

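    # Smoothing is a simple exponential blend that only kicks in when the mask
    # changes abruptly between frames (mean absolute difference above
    # change_threshold):
    #
    #     smoothed = smoothing_factor * current + (1 - smoothing_factor) * previous
    #
    # With the default smoothing_factor of 0.7 the current frame still dominates;
    # the previous mask just damps sudden jumps and flicker.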
    def smooth(self, current_mask: np.ndarray) -> Tuple[np.ndarray, bool]:
        """Apply temporal smoothing"""
        self.total_frames += 1
        corrected = False

        if self.previous_mask is not None:
            diff = np.mean(np.abs(current_mask - self.previous_mask))

            if diff > self.change_threshold:
                smoothed_mask = (self.smoothing_factor * current_mask +
                                 (1 - self.smoothing_factor) * self.previous_mask)
                self.correction_count += 1
                corrected = True
            else:
                smoothed_mask = current_mask
        else:
            smoothed_mask = current_mask

        self.previous_mask = smoothed_mask.copy()
        return smoothed_mask, corrected

    def get_correction_ratio(self) -> float:
        """Get ratio of frames that needed correction"""
        return self.correction_count / max(self.total_frames, 1)


class HairSegmentationPipeline:
    """Main hair segmentation pipeline with multiple models and fallbacks"""

    def __init__(self, config: Optional[Dict] = None):
        self.config = config or {}
        self.models = {}
        self.active_model = None
        self.fallback_models = []
        self.temporal_smoother = TemporalSmoother()
        self.initialized = False

        self._setup_models()

    def _setup_models(self):
        """Setup available models"""
        try:
            sam2_model = SAM2Model(
                model_path=self.config.get('sam2_model_path'),
                device=self.config.get('device', 'auto')
            )
            self.models['sam2'] = sam2_model

            matanyone_model = MatAnyoneModel(
                use_hf_api=self.config.get('use_hf_api', True),
                hf_token=self.config.get('hf_token')
            )
            self.models['matanyone'] = matanyone_model

            traditional_model = TraditionalCVModel()
            self.models['traditional'] = traditional_model

        except Exception as e:
            logger.error(f"Model setup failed: {e}")

    def initialize(self, preferred_model: str = 'sam2') -> bool:
        """Initialize the pipeline"""
        logger.info("🚀 Initializing Hair Segmentation Pipeline...")

        if preferred_model in self.models:
            if self.models[preferred_model].initialize():
                self.active_model = preferred_model
                logger.info(f"✅ Primary model {preferred_model} initialized")
            else:
                logger.warning(f"⚠️ Primary model {preferred_model} failed")

        for model_name, model in self.models.items():
            if model_name != self.active_model:
                if model.initialize():
                    self.fallback_models.append(model_name)
                    logger.info(f"✅ Fallback model {model_name} ready")

        if self.active_model or self.fallback_models:
            self.initialized = True
            logger.info(f"🎯 Pipeline ready - Active: {self.active_model}, Fallbacks: {self.fallback_models}")
            return True
        else:
            logger.error("❌ No working models available")
            return False

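    # Per-frame flow: try the active model first, then each initialized fallback
    # in turn; if everything fails, return an all-zero mask and record the error.
    # Optional temporal smoothing and the quality metrics below are then applied
    # to whichever mask survived.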
    def segment_frame(self, frame: np.ndarray,
                      apply_temporal_smoothing: bool = True) -> SegmentationResult:
        """Segment hair in a single frame"""
        if not self.initialized:
            raise RuntimeError("Pipeline not initialized")

        import time
        start_time = time.time()

        # Primary model
        mask, model_used, error_msg = self._try_segment_with_model(frame, self.active_model)

        # Fallback chain
        if mask is None:
            for fallback_model in self.fallback_models:
                mask, model_used, error_msg = self._try_segment_with_model(frame, fallback_model)
                if mask is not None:
                    break

        # Last resort: empty mask
        if mask is None:
            h, w = frame.shape[:2]
            mask = np.zeros((h, w), dtype=np.float32)
            model_used = "none"
            error_msg = "All models failed"

        # Temporal smoothing
        corrected = False
        if apply_temporal_smoothing:
            mask, corrected = self.temporal_smoother.smooth(mask)

        # Quality metrics
        processing_time = time.time() - start_time
        confidence = self._calculate_confidence(mask)
        coverage = self._calculate_coverage(mask)
        asymmetry = self._calculate_asymmetry(mask)
        quality = self._calculate_quality(mask)

        return SegmentationResult(
            mask=mask,
            confidence=confidence,
            coverage_percent=coverage,
            asymmetry_score=asymmetry,
            processing_time=processing_time,
            fallback_used=(model_used != self.active_model),
            quality_score=quality,
            error_message=error_msg
        )

    def _try_segment_with_model(self, frame: np.ndarray, model_name: str) -> Tuple[Optional[np.ndarray], str, Optional[str]]:
        """Try to segment with a specific model"""
        if model_name not in self.models:
            return None, model_name, f"Model {model_name} not available"

        try:
            mask = self.models[model_name].segment(frame)
            return mask, model_name, None
        except Exception as e:
            error_msg = f"Model {model_name} failed: {str(e)}"
            logger.warning(error_msg)
            return None, model_name, error_msg

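    # Confidence is a cheap proxy computed from the mask itself (no ground truth):
    # a Canny edge-density term plus a Sobel-gradient smoothness term, combined as
    #
    #     confidence = min(0.3 * edge_ratio + 0.7 * smoothness, 1.0)
    #
    # Higher values indicate a mask with a crisp but not noisy boundary.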
    def _calculate_confidence(self, mask: np.ndarray) -> float:
        """Calculate mask confidence using OpenCV instead of skimage"""
        # Edge density of the binarised mask
        edges = cv2.Canny((mask * 255).astype(np.uint8), 50, 150)
        edge_ratio = np.sum(edges > 0) / mask.size

        # Boundary smoothness from the gradient magnitude
        grad_x = cv2.Sobel(mask, cv2.CV_64F, 1, 0, ksize=3)
        grad_y = cv2.Sobel(mask, cv2.CV_64F, 0, 1, ksize=3)
        gradient_magnitude = np.sqrt(grad_x**2 + grad_y**2)
        smoothness = 1.0 / (1.0 + np.std(gradient_magnitude))

        return min(edge_ratio * 0.3 + smoothness * 0.7, 1.0)

    def _calculate_coverage(self, mask: np.ndarray) -> float:
        """Calculate hair coverage percentage"""
        return (np.sum(mask > 0.5) / mask.size) * 100

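    # Asymmetry compares the left half of the mask with a mirrored right half and
    # returns the mean absolute difference (0 = perfectly symmetric). Hair is
    # rarely perfectly symmetric, so this only contributes a mild penalty to the
    # overall quality score.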
    def _calculate_asymmetry(self, mask: np.ndarray) -> float:
        """Calculate left-right asymmetry score"""
        h, w = mask.shape[:2]
        center_x = w // 2

        left_half = mask[:, :center_x]
        right_half = np.fliplr(mask[:, center_x:])

        # Crop both halves to the same width before comparing
        min_width = min(left_half.shape[1], right_half.shape[1])
        left_half = left_half[:, :min_width]
        right_half = right_half[:, :min_width]

        return np.mean(np.abs(left_half - right_half))

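    # Overall quality is a weighted blend of the three metrics above:
    #
    #     quality = 0.5 * confidence + 0.3 * coverage + 0.2 * (1 - min(asymmetry, 1))
    #
    # where coverage is rescaled to [0, 1]. The weights are heuristic.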
    def _calculate_quality(self, mask: np.ndarray) -> float:
        """Calculate overall mask quality"""
        confidence = self._calculate_confidence(mask)
        coverage = self._calculate_coverage(mask) / 100.0
        asymmetry_penalty = 1.0 - min(self._calculate_asymmetry(mask), 1.0)

        return (confidence * 0.5 + coverage * 0.3 + asymmetry_penalty * 0.2)

    def get_pipeline_stats(self) -> Dict:
        """Get pipeline performance statistics"""
        return {
            'active_model': self.active_model,
            'fallback_models': self.fallback_models,
            'temporal_correction_ratio': self.temporal_smoother.get_correction_ratio(),
            'total_frames_processed': self.temporal_smoother.total_frames,
            'corrections_applied': self.temporal_smoother.correction_count
        }


def create_pipeline(config: Optional[Dict] = None) -> HairSegmentationPipeline:
    """Create and initialize hair segmentation pipeline"""
    pipeline = HairSegmentationPipeline(config)
    pipeline.initialize()
    return pipeline


def segment_image(image_path: str, config: Optional[Dict] = None) -> SegmentationResult:
    """Segment hair in a single image"""
    pipeline = create_pipeline(config)
    frame = cv2.imread(image_path)
    if frame is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    return pipeline.segment_frame(frame)


def segment_video_frames(video_frames: List[np.ndarray],
                         config: Optional[Dict] = None) -> List[SegmentationResult]:
    """Segment hair in multiple video frames"""
    pipeline = create_pipeline(config)
    results = []

    for frame in video_frames:
        result = pipeline.segment_frame(frame)
        results.append(result)

    return results


if __name__ == "__main__":
    # Smoke test with a random frame; the traditional CV fallback needs no
    # checkpoints or tokens, so this runs even without SAM2/MatAnyone installed
    config = {
        'sam2_model_path': None,
        'device': 'auto',
        'use_hf_api': True,
        'hf_token': None
    }

    pipeline = create_pipeline(config)

    test_frame = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
    result = pipeline.segment_frame(test_frame)

    print("Segmentation Results:")
    print(f"  Coverage: {result.coverage_percent:.1f}%")
    print(f"  Confidence: {result.confidence:.3f}")
    print(f"  Quality: {result.quality_score:.3f}")
    print(f"  Processing time: {result.processing_time:.2f}s")
    print(f"  Fallback used: {result.fallback_used}")

    stats = pipeline.get_pipeline_stats()
    print("\nPipeline Stats:")
    print(f"  Active model: {stats['active_model']}")
    print(f"  Fallbacks: {stats['fallback_models']}")
    print(f"  Correction ratio: {stats['temporal_correction_ratio']:.3f}")