File size: 16,935 Bytes
a0ffb03
 
 
53fdc22
a0ffb03
 
 
7b9f1c5
53fdc22
a0ffb03
53fdc22
 
69083e6
a0ffb03
53fdc22
a0ffb03
53fdc22
 
 
 
 
 
 
 
 
 
 
 
7b9f1c5
53fdc22
 
 
 
7b9f1c5
53fdc22
a0ffb03
 
53fdc22
 
 
a0ffb03
 
53fdc22
 
 
 
 
 
 
 
 
 
a0ffb03
53fdc22
 
 
 
 
 
 
 
a0ffb03
53fdc22
 
a0ffb03
53fdc22
a0ffb03
53fdc22
 
 
 
 
 
 
 
 
 
a0ffb03
 
53fdc22
7b9f1c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53fdc22
 
 
 
 
 
 
69083e6
a8a12b2
 
 
69083e6
 
 
 
 
 
a8a12b2
69083e6
e94d263
 
 
 
 
 
 
 
69083e6
 
 
e94d263
 
 
 
69083e6
a8a12b2
69083e6
a8a12b2
69083e6
e94d263
 
 
7b9f1c5
69083e6
 
7b9f1c5
 
 
 
 
 
a8a12b2
69083e6
7b9f1c5
 
 
69083e6
7b9f1c5
69083e6
 
7b9f1c5
69083e6
 
 
 
7b9f1c5
 
69083e6
7b9f1c5
69083e6
7b9f1c5
 
69083e6
 
7b9f1c5
69083e6
 
7b9f1c5
69083e6
 
 
 
 
 
 
 
53fdc22
 
7b9f1c5
53fdc22
7b9f1c5
 
 
 
 
 
 
a8a12b2
 
 
7b9f1c5
 
 
a8a12b2
7b9f1c5
 
 
 
 
 
a8a12b2
 
7b9f1c5
 
 
e94d263
 
7b9f1c5
e94d263
 
 
 
 
 
 
 
 
7b9f1c5
a8a12b2
7b9f1c5
a8a12b2
7b9f1c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a8a12b2
7b9f1c5
 
a8a12b2
7b9f1c5
 
a8a12b2
7b9f1c5
 
 
 
 
 
 
 
e94d263
 
 
 
 
 
 
7b9f1c5
a8a12b2
 
7b9f1c5
a8a12b2
7b9f1c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53fdc22
7b9f1c5
 
53fdc22
7b9f1c5
 
53fdc22
7b9f1c5
 
53fdc22
7b9f1c5
53fdc22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7b9f1c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53fdc22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
#!/usr/bin/env python3
"""
utils.refinement
High-quality mask refinement for BackgroundFX Pro.
"""

from __future__ import annotations
from typing import Any, Optional, Tuple, List
import logging

import cv2
import numpy as np
import torch

log = logging.getLogger(__name__)

# ============================================================================
# CUSTOM EXCEPTION
# ============================================================================
class MaskRefinementError(Exception):
    """Raised when mask refinement cannot produce a valid result."""

# ============================================================================
# EXPORTS
# ============================================================================
__all__ = [
    "refine_mask_hq",
    "refine_masks_batch",
    "MaskRefinementError",
]

# ============================================================================
# MAIN API - SINGLE FRAME
# ============================================================================
def refine_mask_hq(
    image: np.ndarray,
    mask: np.ndarray,
    matanyone_model: Optional[Any] = None,
    fallback_enabled: bool = True
) -> np.ndarray:
    """Refine a single binary mask, preferring AI refinement over classical CV.

    Strategy order: MatAnyone model (if provided and its output validates),
    then classical morphology/edge-aware refinement (if enabled), and finally
    the unmodified input mask as a last resort.

    Args:
        image: Original BGR image.
        mask: Initial binary mask (0/255).
        matanyone_model: Optional MatAnyone model for AI refinement.
        fallback_enabled: Whether to use fallback methods if AI fails.

    Returns:
        Refined binary mask (0/255).

    Raises:
        MaskRefinementError: On missing inputs or mismatched shapes.
    """
    if image is None or mask is None:
        raise MaskRefinementError("Invalid input image or mask")

    if image.shape[:2] != mask.shape[:2]:
        raise MaskRefinementError(f"Image shape {image.shape[:2]} doesn't match mask shape {mask.shape[:2]}")

    # AI-based refinement takes priority when a model was supplied.
    if matanyone_model is not None:
        try:
            candidate = _refine_with_matanyone(image, mask, matanyone_model)
            if _validate_refined_mask(candidate, mask):
                return candidate
            log.warning("MatAnyone refinement failed validation")
        except Exception as e:
            log.warning(f"MatAnyone refinement failed: {e}")

    if not fallback_enabled:
        return mask

    # Classical CV fallback; on any error, hand back the original mask.
    try:
        return _classical_refinement(image, mask)
    except Exception as e:
        log.warning(f"Classical refinement failed: {e}")
        return mask

# ============================================================================
# BATCH PROCESSING FOR TEMPORAL CONSISTENCY
# ============================================================================
def refine_masks_batch(
    frames: List[np.ndarray],
    masks: List[np.ndarray],
    matanyone_model: Optional[Any] = None,
    fallback_enabled: bool = True
) -> List[np.ndarray]:
    """Refine a sequence of masks, exploiting MatAnyone temporal consistency.

    Args:
        frames: List of BGR images.
        masks: List of initial binary masks (one per frame).
        matanyone_model: MatAnyone InferenceCore model, or None.
        fallback_enabled: Whether to fall back to classical refinement.

    Returns:
        List of refined binary masks (the input masks if everything fails).

    Raises:
        MaskRefinementError: If frame and mask counts differ.
    """
    if not frames or not masks:
        return masks

    if len(frames) != len(masks):
        raise MaskRefinementError(f"Frame count {len(frames)} doesn't match mask count {len(masks)}")

    if matanyone_model is not None:
        try:
            batch_result = _refine_batch_with_matanyone(frames, masks, matanyone_model)
            # Accept the batch only when every mask passes validation.
            if all(_validate_refined_mask(r, m) for r, m in zip(batch_result, masks)):
                return batch_result
            log.warning("Batch MatAnyone refinement failed validation")
        except Exception as e:
            log.warning(f"Batch MatAnyone refinement failed: {e}")

    if not fallback_enabled:
        return masks

    # Per-frame classical fallback (no temporal consistency).
    return [_classical_refinement(frame, frame_mask) for frame, frame_mask in zip(frames, masks)]

# ============================================================================
# AI-BASED REFINEMENT - SINGLE FRAME
# ============================================================================
def _refine_with_matanyone(
    image: np.ndarray,
    mask: np.ndarray,
    model: Any
) -> np.ndarray:
    """Use a MatAnyone model to refine a single mask.

    Args:
        image: Original BGR image (H, W, 3).
        mask: Initial mask; channel layout and dtype are normalized internally.
        model: MatAnyone InferenceCore (or compatible) model object.

    Returns:
        Refined binary mask (0/255) with the same spatial size as ``image``.

    Raises:
        MaskRefinementError: If the mask cannot be normalized to match the
            image, or the model has no usable entry point / returns nothing.
    """
    try:
        # Prefer GPU when available (e.g. Tesla T4 on cuda:0).
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        # The model expects RGB in [0, 1]; OpenCV delivers BGR uint8.
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        h, w = image_rgb.shape[:2]

        # (H, W, C) -> (1, C, H, W) float tensor on the target device.
        image_tensor = torch.from_numpy(image_rgb).permute(2, 0, 1).float() / 255.0
        image_tensor = image_tensor.unsqueeze(0).to(device)

        # Collapse any multi-channel mask down to a single 2D channel.
        if mask.ndim == 3:
            if mask.shape[2] == 3:
                mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
            else:
                mask = mask[:, :, 0]

        # Normalize dtype to uint8 0-255 (float masks may be in [0, 1]).
        if mask.dtype != np.uint8:
            mask = (mask * 255).astype(np.uint8) if mask.max() <= 1 else mask.astype(np.uint8)

        # Explicit validation instead of `assert`: asserts vanish under -O.
        if mask.ndim != 2:
            raise MaskRefinementError(f"Mask must be 2D after conversion, got shape {mask.shape}")
        if mask.shape != (h, w):
            raise MaskRefinementError(f"Mask shape {mask.shape} doesn't match image shape ({h}, {w})")

        # (H, W) -> (1, 1, H, W) float tensor on the target device.
        mask_tensor = torch.from_numpy(mask).float() / 255.0
        mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0).to(device)

        # Probe the model for a known inference entry point.
        methods = [m for m in dir(model) if not m.startswith('_')]
        log.debug(f"MatAnyone InferenceCore methods: {methods}")

        result = None
        with torch.no_grad():
            if hasattr(model, 'step'):
                # Step method for iterative processing.
                result = model.step(image_tensor, mask_tensor)
            elif hasattr(model, 'process_frame'):
                result = model.process_frame(image_tensor, mask_tensor)
            elif hasattr(model, 'forward'):
                result = model.forward(image_tensor, mask_tensor)
            elif hasattr(model, '__call__'):
                result = model(image_tensor, mask_tensor)
            else:
                raise MaskRefinementError(f"No recognized method. Available: {methods}")

        if result is None:
            raise MaskRefinementError("MatAnyone returned None")

        # Results come back in assorted containers; pull out the alpha matte.
        alpha = _extract_alpha_from_result(result)

        if isinstance(alpha, torch.Tensor):
            alpha = alpha.squeeze().cpu().numpy()

        # Reduce any leftover channel dimension to 2D.
        if alpha.ndim == 3:
            alpha = alpha[0] if alpha.shape[0] == 1 else alpha.mean(axis=0)

        if alpha.dtype != np.uint8:
            alpha = (alpha * 255).clip(0, 255).astype(np.uint8)

        if alpha.shape != (h, w):
            alpha = cv2.resize(alpha, (w, h), interpolation=cv2.INTER_LINEAR)

        return _process_mask(alpha)

    except MaskRefinementError:
        # Already descriptive; don't double-wrap our own error type.
        raise
    except Exception as e:
        log.error(f"MatAnyone processing error: {str(e)}")
        # Chain the cause so the original traceback survives.
        raise MaskRefinementError(f"MatAnyone processing failed: {str(e)}") from e

# ============================================================================
# AI-BASED REFINEMENT - BATCH
# ============================================================================
def _refine_batch_with_matanyone(
    frames: List[np.ndarray],
    masks: List[np.ndarray],
    model: Any
) -> List[np.ndarray]:
    """Process a batch of frames through MatAnyone for temporal consistency.

    Args:
        frames: List of BGR images, all the same size.
        masks: List of initial masks (one per frame); the first seeds the
            model's temporal memory.
        model: MatAnyone InferenceCore (or compatible) model object.

    Returns:
        List of refined uint8 masks, one per input frame.

    Raises:
        MaskRefinementError: If a mask cannot be normalized to 2D or any
            model call fails.
    """
    def _to_2d_uint8(m: np.ndarray) -> np.ndarray:
        # Collapse channels and normalize dtype to uint8 0-255
        # (float masks may be in [0, 1]).
        if m.ndim == 3:
            m = cv2.cvtColor(m, cv2.COLOR_BGR2GRAY) if m.shape[2] == 3 else m[:, :, 0]
        if m.dtype != np.uint8:
            m = (m * 255).astype(np.uint8) if m.max() <= 1 else m.astype(np.uint8)
        # Explicit check instead of `assert`: asserts vanish under -O.
        if m.ndim != 2:
            raise MaskRefinementError(f"Mask must be 2D, got shape {m.shape}")
        return m

    try:
        # Prefer GPU when available (e.g. Tesla T4 on cuda:0).
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        h, w = frames[0].shape[:2]

        # Convert every BGR frame to a normalized (C, H, W) RGB tensor.
        frame_tensors = []
        for frame in frames:
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame_tensors.append(torch.from_numpy(frame_rgb).permute(2, 0, 1).float() / 255.0)

        # (N, C, H, W) batch for models that support direct batch processing.
        batch_tensor = torch.stack(frame_tensors).to(device)

        # The first mask seeds the model's temporal memory.
        first_mask = _to_2d_uint8(masks[0])
        first_mask_tensor = torch.from_numpy(first_mask).float() / 255.0
        first_mask_tensor = first_mask_tensor.unsqueeze(0).unsqueeze(0).to(device)

        refined_masks: List[np.ndarray] = []

        with torch.no_grad():
            if hasattr(model, 'process_batch'):
                # Direct batch processing.
                results = model.process_batch(batch_tensor, first_mask_tensor)
                for result in results:
                    alpha = _extract_alpha_from_result(result)
                    refined_masks.append(_tensor_to_mask(alpha, h, w))

            elif hasattr(model, 'step'):
                # Sequential processing with memory: only the first frame gets
                # a mask, later frames rely on the model's internal state.
                for i, frame_tensor in enumerate(frame_tensors):
                    frame_on_device = frame_tensor.unsqueeze(0).to(device)
                    result = model.step(frame_on_device, first_mask_tensor if i == 0 else None)
                    alpha = _extract_alpha_from_result(result)
                    refined_masks.append(_tensor_to_mask(alpha, h, w))

            else:
                # No batch/step API: run each frame with its own mask.
                log.warning("MatAnyone batch processing not available, using frame-by-frame")
                for frame_tensor, mask in zip(frame_tensors, masks):
                    # Normalize each mask exactly like the first one; the
                    # previous code skipped dtype normalization here, which
                    # broke float masks in [0, 1] (divided by 255 twice).
                    mask = _to_2d_uint8(mask)
                    mask_tensor = torch.from_numpy(mask).float() / 255.0
                    mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0).to(device)
                    frame_on_device = frame_tensor.unsqueeze(0).to(device)

                    result = model(frame_on_device, mask_tensor)
                    alpha = _extract_alpha_from_result(result)
                    refined_masks.append(_tensor_to_mask(alpha, h, w))

        return refined_masks

    except MaskRefinementError:
        # Already descriptive; don't double-wrap our own error type.
        raise
    except Exception as e:
        log.error(f"Batch MatAnyone processing error: {str(e)}")
        # Chain the cause so the original traceback survives.
        raise MaskRefinementError(f"Batch processing failed: {str(e)}") from e

# ============================================================================
# HELPER FUNCTIONS
# ============================================================================
def _extract_alpha_from_result(result):
    """Extract alpha matte from various result formats."""
    if isinstance(result, (tuple, list)):
        return result[0] if len(result) > 0 else None
    elif isinstance(result, dict):
        return result.get('alpha', result.get('matte', result.get('mask', None)))
    else:
        return result

def _tensor_to_mask(tensor, target_h, target_w):
    """Convert tensor to numpy mask with proper sizing."""
    if isinstance(tensor, torch.Tensor):
        mask = tensor.squeeze().cpu().numpy()
    else:
        mask = tensor
    
    if mask.ndim == 3:
        mask = mask[0] if mask.shape[0] == 1 else mask.mean(axis=0)
    
    if mask.dtype != np.uint8:
        mask = (mask * 255).clip(0, 255).astype(np.uint8)
    
    if mask.shape != (target_h, target_w):
        mask = cv2.resize(mask, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
    
    return mask

def _validate_refined_mask(refined: np.ndarray, original: np.ndarray) -> bool:
    """Check if refined mask is reasonable."""
    if refined is None or refined.size == 0:
        return False
    
    refined_area = np.sum(refined > 127)
    original_area = np.sum(original > 127)
    
    if refined_area == 0:
        return False
    
    ratio = refined_area / max(original_area, 1)
    return 0.5 <= ratio <= 2.0

def _process_mask(mask: np.ndarray) -> np.ndarray:
    """Normalize any mask representation into a strict binary 0/255 uint8 mask."""
    # Float masks in [0, 1] are rescaled to 0-255 before casting.
    if mask.dtype in (np.float32, np.float64) and mask.max() <= 1.0:
        mask = (mask * 255).astype(np.uint8)

    if mask.dtype != np.uint8:
        mask = mask.astype(np.uint8)

    # Collapse color masks to a single channel.
    if mask.ndim == 3:
        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)

    # Hard threshold at the midpoint to guarantee a clean 0/255 result.
    return cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)[1]

# ============================================================================
# CLASSICAL REFINEMENT
# ============================================================================
def _classical_refinement(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Refine a mask with a pipeline of classical CV operations."""
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))

    # Close small holes, then remove small speckles.
    refined = cv2.morphologyEx(mask.copy(), cv2.MORPH_CLOSE, ellipse)
    refined = cv2.morphologyEx(refined, cv2.MORPH_OPEN, ellipse)

    # Align mask edges with image edges, soften them, and drop stray blobs.
    refined = _edge_aware_smooth(image, refined)
    refined = _feather_edges(refined, radius=3)
    return _remove_small_components(refined, min_area_ratio=0.005)

def _edge_aware_smooth(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Smooth the mask with a single-channel guided filter (guide = grayscale image).

    Fits a per-pixel linear model output ≈ a * guide + b over a box window so
    mask edges follow intensity edges in the image.
    """
    radius, eps = 5, 0.01
    window = (radius, radius)

    guide = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY).astype(np.float32) / 255.0
    target = mask.astype(np.float32) / 255.0

    # Local first- and second-order statistics over the box window.
    guide_mean = cv2.boxFilter(guide, -1, window)
    target_mean = cv2.boxFilter(target, -1, window)
    cross_mean = cv2.boxFilter(guide * target, -1, window)
    guide_sq_mean = cv2.boxFilter(guide * guide, -1, window)

    covariance = cross_mean - guide_mean * target_mean
    variance = guide_sq_mean - guide_mean * guide_mean

    # Linear coefficients; eps regularizes flat regions.
    a = covariance / (variance + eps)
    b = target_mean - a * guide_mean

    smoothed = cv2.boxFilter(a, -1, window) * guide + cv2.boxFilter(b, -1, window)
    return (smoothed * 255).clip(0, 255).astype(np.uint8)

def _feather_edges(mask: np.ndarray, radius: int = 3) -> np.ndarray:
    """Gaussian-blur then re-threshold the mask to soften jagged edges."""
    if radius <= 0:
        return mask

    kernel_size = radius * 2 + 1
    softened = cv2.GaussianBlur(mask, (kernel_size, kernel_size), radius / 2)
    return cv2.threshold(softened, 127, 255, cv2.THRESH_BINARY)[1]

def _remove_small_components(mask: np.ndarray, min_area_ratio: float = 0.005) -> np.ndarray:
    """Drop connected components smaller than a fraction of the frame area.

    The largest component is always kept, even if it falls below the cutoff.
    """
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)

    # Label 0 is background; with no foreground components there is
    # nothing to filter.
    if num_labels <= 1:
        return mask

    areas = stats[1:, cv2.CC_STAT_AREA]
    if len(areas) == 0:
        return mask

    min_area = int(mask.shape[0] * mask.shape[1] * min_area_ratio)
    largest = int(np.argmax(areas)) + 1

    kept = np.zeros_like(mask)
    for label in range(1, num_labels):
        if label == largest or stats[label, cv2.CC_STAT_AREA] >= min_area:
            kept[labels == label] = 255

    return kept