MogensR committed on
Commit 53fdc22 · 1 Parent(s): 151a692

Update utils/refinement.py

Files changed (1)
  1. utils/refinement.py +213 -148
utils/refinement.py CHANGED
@@ -1,167 +1,232 @@
 #!/usr/bin/env python3
 """
 utils.refinement
-─────────────────────────────────────────────────────────────────────────────
-Single-frame mask refinement for BackgroundFX Pro.
-
-Public API
-----------
-refine_mask_hq(image, mask, matanyone_processor, fallback_enabled=True) -> np.ndarray
+High-quality mask refinement for BackgroundFX Pro.
 """

 from __future__ import annotations
-from typing import Any, Tuple, Optional
-import logging, cv2, torch, numpy as np
-
-log = logging.getLogger(__name__)
+from typing import Any, Optional, Tuple
+import logging

-# Quality thresholds (same as before)
-MIN_AREA_RATIO = 0.015
-MAX_AREA_RATIO = 0.97
+import cv2
+import numpy as np

-# ────────────────────────────────────────────────────────────────────────────
-# Public
-# ────────────────────────────────────────────────────────────────────────────
-__all__ = ["refine_mask_hq"]
+log = logging.getLogger(__name__)

+# ============================================================================
+# CUSTOM EXCEPTION
+# ============================================================================
+class MaskRefinementError(Exception):
+    """Custom exception for mask refinement errors"""
+    pass
+
+# ============================================================================
+# EXPORTS
+# ============================================================================
+__all__ = [
+    "refine_mask_hq",
+    "MaskRefinementError",
+]
+
+# ============================================================================
+# MAIN API
+# ============================================================================
 def refine_mask_hq(
     image: np.ndarray,
-    mask: np.ndarray,
-    matanyone_processor: Any,
-    fallback_enabled: bool = True,
+    mask: np.ndarray,
+    matanyone_model: Optional[Any] = None,
+    fallback_enabled: bool = True
 ) -> np.ndarray:
     """
-    1) Try MatAnyOne high-quality refinement.
-    2) Otherwise OpenCV “enhanced” filter.
-    3) GrabCut and saliency fallbacks.
-    Always returns uint8 mask (0/255).
+    High-quality mask refinement with multiple strategies.
+
+    Args:
+        image: Original BGR image
+        mask: Initial binary mask (0/255)
+        matanyone_model: Optional MatAnyone model for AI refinement
+        fallback_enabled: Whether to use fallback methods if AI fails
+
+    Returns:
+        Refined binary mask (0/255)
     """
-    mask = _process_mask(mask)
-
-    # 1 — MatAnyOne
-    if matanyone_processor is not None:
+    if image is None or mask is None:
+        raise MaskRefinementError("Invalid input image or mask")
+
+    if image.shape[:2] != mask.shape[:2]:
+        raise MaskRefinementError(f"Image shape {image.shape[:2]} doesn't match mask shape {mask.shape[:2]}")
+
+    # Try AI-based refinement first if model available
+    if matanyone_model is not None:
         try:
-            refined = _matanyone_refine(image, mask, matanyone_processor)
-            if refined is not None and _validate_mask_quality(refined, image.shape[:2]):
+            refined = _refine_with_matanyone(image, mask, matanyone_model)
+            if _validate_refined_mask(refined, mask):
                 return refined
-            log.warning("MatAnyOne produced poor mask; fallback")
+            log.warning("MatAnyone refinement failed validation")
         except Exception as e:
-            log.warning(f"MatAnyOne error: {e}")
-
-    # 2 — OpenCV “enhanced” bilateral+guided+MORPH
-    try:
-        refined = _opencv_enhance(image, mask)
-        if _validate_mask_quality(refined, image.shape[:2]):
-            return refined
-    except Exception as e:
-        log.debug(f"OpenCV enhance error: {e}")
-
-    # 3 — GrabCut + saliency double-fallback
-    try:
-        gc = _refine_with_grabcut(image, mask)
-        if _validate_mask_quality(gc, image.shape[:2]):
-            return gc
-        sal = _refine_with_saliency(image, mask)
-        if _validate_mask_quality(sal, image.shape[:2]):
-            return sal
-    except Exception as e:
-        log.debug(f"GrabCut/saliency fallback error: {e}")
-
-    # last resort
-    return mask if fallback_enabled else _opencv_enhance(image, mask)
-
-# ────────────────────────────────────────────────────────────────────────────
-# MatAnyOne wrapper (safe)
-# ────────────────────────────────────────────────────────────────────────────
-def _matanyone_refine(img, mask, proc) -> Optional[np.ndarray]:
-    if not (hasattr(proc, "step") and hasattr(proc, "output_prob_to_mask")):
-        return None
-    # image tensor (C,H,W) float32 0-1
-    anp = img.astype(np.float32)
-    if anp.max() > 1: anp /= 255.0
-    anp = np.transpose(anp, (2,0,1))
-    img_t = torch.from_numpy(anp).unsqueeze(0).to(proc.device if hasattr(proc,"device") else "cpu")
-    mask_f = mask.astype(np.float32)/255.0
-    mask_t = torch.from_numpy(mask_f).unsqueeze(0).to(img_t.device)
-
-    with torch.no_grad():
-        prob = proc.step(img_t, mask_t, objects=[1])
-        m = proc.output_prob_to_mask(prob).squeeze().cpu().numpy()
-    if m.max() <= 1: m *= 255
-    return m.astype(np.uint8)
-
-# ────────────────────────────────────────────────────────────────────────────
-# OpenCV enhanced filter chain
-# ────────────────────────────────────────────────────────────────────────────
-def _opencv_enhance(img, mask):
-    if mask.ndim == 3: mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
-    if mask.max()<=1: mask = (mask*255).astype(np.uint8)
-    m = cv2.bilateralFilter(mask, 9, 75, 75)
-    m = _guided_filter(img, m, r=8, eps=0.2)
-    m = cv2.morphologyEx(m, cv2.MORPH_CLOSE, cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(5,5)))
-    m = cv2.morphologyEx(m, cv2.MORPH_OPEN, cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(3,3)))
-    m = cv2.GaussianBlur(m,(3,3),0.8)
-    _,m = cv2.threshold(m,127,255,cv2.THRESH_BINARY)
-    return m
-
-def _guided_filter(guide, mask, r=8, eps=0.2):
-    g = cv2.cvtColor(guide, cv2.COLOR_BGR2GRAY).astype(np.float32)/255.0
-    m = mask.astype(np.float32)/255.0
-    k = 2*r+1
-    mean_g = cv2.boxFilter(g, -1, (k,k))
-    mean_m = cv2.boxFilter(m, -1, (k,k))
-    corr_gm = cv2.boxFilter(g*m, -1, (k,k))
-    cov = corr_gm - mean_g*mean_m
-    var_g = cv2.boxFilter(g*g, -1, (k,k)) - mean_g*mean_g
-    a = cov/(var_g+eps)
-    b = mean_m - a*mean_g
-    mean_a = cv2.boxFilter(a, -1, (k,k))
-    mean_b = cv2.boxFilter(b, -1, (k,k))
-    out = (mean_a*g+mean_b)*255
-    return out.astype(np.uint8)
-
-# ────────────────────────────────────────────────────────────────────────────
-# GrabCut & saliency fallbacks
-# ────────────────────────────────────────────────────────────────────────────
-def _refine_with_grabcut(img, seed):
-    h,w = img.shape[:2]
-    gc = np.full((h,w), cv2.GC_PR_BGD, np.uint8)
-    gc[seed>200] = cv2.GC_FGD
-    rect = (w//4, h//6, w//2, int(h*0.7))
-    bgd,fgd = np.zeros((1,65),np.float64), np.zeros((1,65),np.float64)
-    cv2.grabCut(img, gc, rect, bgd, fgd, 3, cv2.GC_INIT_WITH_MASK)
-    return np.where((gc==cv2.GC_FGD)|(gc==cv2.GC_PR_FGD),255,0).astype(np.uint8)
-
-def _refine_with_saliency(img, seed):
-    sal = _compute_saliency(img)
-    if sal is None: return seed
-    high = (sal>0.6).astype(np.uint8)*255
-    cy,cx = img.shape[0]//2, img.shape[1]//2
-    if np.any(seed>127):
-        ys,xs = np.where(seed>127); cy,cx=int(np.mean(ys)),int(np.mean(xs))
-    ff = high.copy(); cv2.floodFill(ff,None,(cx,cy),255,loDiff=5,upDiff=5)
-    return ff
-
-def _compute_saliency(img):
-    try:
-        if hasattr(cv2,"saliency"):
-            s=cv2.saliency.StaticSaliencySpectralResidual_create()
-            ok,sm=s.computeSaliency(img)
-            if ok: return (sm-sm.min())/max(1e-6,sm.max()-sm.min())
-    except Exception: pass
-    return None
-
-# ────────────────────────────────────────────────────────────────────────────
-# Helpers
-# ────────────────────────────────────────────────────────────────────────────
-def _process_mask(mask):
-    if mask.ndim==3: mask=cv2.cvtColor(mask,cv2.COLOR_BGR2GRAY)
-    if mask.dtype!=np.uint8:
-        mask = (mask*255).astype(np.uint8) if mask.max()<=1 else mask.astype(np.uint8)
-    _,mask=cv2.threshold(mask,127,255,cv2.THRESH_BINARY)
+            log.warning(f"MatAnyone refinement failed: {e}")
+
+    # Fallback to classical refinement methods
+    if fallback_enabled:
+        try:
+            return _classical_refinement(image, mask)
+        except Exception as e:
+            log.warning(f"Classical refinement failed: {e}")
+            return mask  # Return original if all fails
+
     return mask

-def _validate_mask_quality(mask, shape: Tuple[int,int]) -> bool:
-    h,w = shape
-    ratio = np.sum(mask>127)/(h*w)
-    return MIN_AREA_RATIO <= ratio <= MAX_AREA_RATIO
+# ============================================================================
+# AI-BASED REFINEMENT
+# ============================================================================
+def _refine_with_matanyone(
+    image: np.ndarray,
+    mask: np.ndarray,
+    model: Any
+) -> np.ndarray:
+    """Use MatAnyone model for mask refinement."""
+    # Check if model has expected interface
+    if hasattr(model, 'process'):
+        result = model.process(image, mask)
+    elif hasattr(model, 'refine'):
+        result = model.refine(image, mask)
+    elif callable(model):
+        result = model(image, mask)
+    else:
+        raise MaskRefinementError("MatAnyone model doesn't have expected interface")
+
+    # Convert result to binary mask
+    if result is None:
+        raise MaskRefinementError("MatAnyone returned None")
+
+    return _process_mask(result)
+
+# ============================================================================
+# CLASSICAL REFINEMENT
+# ============================================================================
+def _classical_refinement(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
+    """Apply classical CV techniques for mask refinement."""
+    refined = mask.copy()
+
+    # 1. Morphological operations to clean up
+    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
+    refined = cv2.morphologyEx(refined, cv2.MORPH_CLOSE, kernel)
+    refined = cv2.morphologyEx(refined, cv2.MORPH_OPEN, kernel)
+
+    # 2. Edge-aware smoothing
+    refined = _edge_aware_smooth(image, refined)
+
+    # 3. Feather edges slightly
+    refined = _feather_edges(refined, radius=3)
+
+    # 4. Remove small disconnected components
+    refined = _remove_small_components(refined, min_area_ratio=0.005)
+
+    return refined
+
+# ============================================================================
+# HELPER FUNCTIONS
+# ============================================================================
+def _validate_refined_mask(refined: np.ndarray, original: np.ndarray) -> bool:
+    """Check if refined mask is reasonable."""
+    if refined is None or refined.size == 0:
+        return False
+
+    # Check if mask has reasonable coverage
+    refined_area = np.sum(refined > 127)
+    original_area = np.sum(original > 127)
+
+    if refined_area == 0:
+        return False
+
+    # Allow some variation but not extreme changes
+    ratio = refined_area / max(original_area, 1)
+    return 0.5 <= ratio <= 2.0
+
+def _process_mask(mask: np.ndarray) -> np.ndarray:
+    """Convert any mask format to binary 0/255."""
+    if mask.dtype == np.float32 or mask.dtype == np.float64:
+        if mask.max() <= 1.0:
+            mask = (mask * 255).astype(np.uint8)
+
+    if mask.dtype != np.uint8:
+        mask = mask.astype(np.uint8)
+
+    if mask.ndim == 3:
+        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
+
+    _, binary = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
+    return binary
+
+def _edge_aware_smooth(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
+    """Apply edge-aware smoothing using guided filter."""
+    # Convert to float for processing
+    mask_float = mask.astype(np.float32) / 255.0
+
+    # Simple guided filter approximation
+    radius = 5
+    eps = 0.01
+
+    # Use image as guide
+    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY).astype(np.float32) / 255.0
+
+    # Box filter for mean
+    mean_I = cv2.boxFilter(gray, -1, (radius, radius))
+    mean_p = cv2.boxFilter(mask_float, -1, (radius, radius))
+    mean_Ip = cv2.boxFilter(gray * mask_float, -1, (radius, radius))
+
+    # Covariance
+    cov_Ip = mean_Ip - mean_I * mean_p
+
+    # Variance
+    mean_II = cv2.boxFilter(gray * gray, -1, (radius, radius))
+    var_I = mean_II - mean_I * mean_I
+
+    # Coefficients
+    a = cov_Ip / (var_I + eps)
+    b = mean_p - a * mean_I
+
+    # Filter
+    mean_a = cv2.boxFilter(a, -1, (radius, radius))
+    mean_b = cv2.boxFilter(b, -1, (radius, radius))
+
+    refined = mean_a * gray + mean_b
+
+    # Convert back to binary
+    return (refined * 255).clip(0, 255).astype(np.uint8)
+
+def _feather_edges(mask: np.ndarray, radius: int = 3) -> np.ndarray:
+    """Slightly blur edges for smoother transitions."""
+    if radius <= 0:
+        return mask
+
+    # Blur then threshold to maintain binary nature
+    blurred = cv2.GaussianBlur(mask, (radius*2+1, radius*2+1), radius/2)
+    _, binary = cv2.threshold(blurred, 127, 255, cv2.THRESH_BINARY)
+
+    return binary
+
+def _remove_small_components(mask: np.ndarray, min_area_ratio: float = 0.005) -> np.ndarray:
+    """Remove small disconnected components."""
+    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(mask, connectivity=8)
+
+    if num_labels <= 1:
+        return mask
+
+    # Calculate minimum area
+    total_area = mask.shape[0] * mask.shape[1]
+    min_area = int(total_area * min_area_ratio)
+
+    # Find largest component (excluding background)
+    areas = stats[1:, cv2.CC_STAT_AREA]
+    if len(areas) == 0:
+        return mask
+
+    max_label = np.argmax(areas) + 1
+
+    # Keep only components above threshold or the largest one
+    cleaned = np.zeros_like(mask)
+    for label in range(1, num_labels):
+        if stats[label, cv2.CC_STAT_AREA] >= min_area or label == max_label:
+            cleaned[labels == label] = 255
+
+    return cleaned
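
For reviewers who want to exercise the new entry point, here is a minimal single-frame sketch. It only uses the API visible in this commit (refine_mask_hq and MaskRefinementError); the file names are placeholders, and no MatAnyone model is passed, so the call goes through the classical fallback path:

import cv2
from utils.refinement import refine_mask_hq, MaskRefinementError

frame = cv2.imread("frame.png")                              # placeholder BGR frame
rough = cv2.imread("rough_mask.png", cv2.IMREAD_GRAYSCALE)   # placeholder rough mask, 0/255, same H/W as frame

try:
    # matanyone_model=None skips the AI branch; fallback_enabled=True runs the classical pipeline
    refined = refine_mask_hq(frame, rough, matanyone_model=None, fallback_enabled=True)
    cv2.imwrite("refined_mask.png", refined)
except MaskRefinementError as err:
    print(f"Refinement rejected its inputs: {err}")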