MogensR commited on
Commit
9de3cd4
·
1 Parent(s): 9015f7f

Create core/quality.py

Browse files
Files changed (1) hide show
  1. core/quality.py +409 -0
core/quality.py ADDED
@@ -0,0 +1,409 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Quality analysis and metrics for BackgroundFX Pro.
3
+ Provides REAL metrics instead of fake 100% values.
4
+ """
5
+
6
+ import numpy as np
7
+ import cv2
8
+ import torch
9
+ from typing import Dict, List, Optional, Tuple, Any
10
+ from dataclasses import dataclass, field
11
+ from collections import deque
12
+ import logging
13
+ from scipy import signal, ndimage
14
+ from skimage import metrics as skmetrics
15
+ import json
16
+ from pathlib import Path
17
+ from datetime import datetime
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ @dataclass
23
+ class QualityMetrics:
24
+ """Real quality metrics container."""
25
+ # Edge Quality
26
+ edge_accuracy: float = 0.0
27
+ edge_smoothness: float = 0.0
28
+ edge_completeness: float = 0.0
29
+
30
+ # Temporal Quality
31
+ temporal_stability: float = 0.0
32
+ temporal_consistency: float = 0.0
33
+ flicker_score: float = 0.0
34
+
35
+ # Mask Quality
36
+ mask_coverage: float = 0.0
37
+ mask_accuracy: float = 0.0
38
+ mask_confidence: float = 0.0
39
+ hole_ratio: float = 0.0
40
+
41
+ # Detail Preservation
42
+ detail_preservation: float = 0.0
43
+ hair_detail_score: float = 0.0
44
+ texture_quality: float = 0.0
45
+
46
+ # Overall Scores
47
+ overall_quality: float = 0.0
48
+ processing_confidence: float = 0.0
49
+
50
+ # Detailed breakdown
51
+ breakdown: Dict[str, float] = field(default_factory=dict)
52
+ warnings: List[str] = field(default_factory=list)
53
+
54
+ def to_dict(self) -> Dict[str, Any]:
55
+ """Convert to dictionary."""
56
+ return {
57
+ 'edge_accuracy': round(self.edge_accuracy, 3),
58
+ 'edge_smoothness': round(self.edge_smoothness, 3),
59
+ 'edge_completeness': round(self.edge_completeness, 3),
60
+ 'temporal_stability': round(self.temporal_stability, 3),
61
+ 'temporal_consistency': round(self.temporal_consistency, 3),
62
+ 'flicker_score': round(self.flicker_score, 3),
63
+ 'mask_coverage': round(self.mask_coverage, 3),
64
+ 'mask_accuracy': round(self.mask_accuracy, 3),
65
+ 'mask_confidence': round(self.mask_confidence, 3),
66
+ 'hole_ratio': round(self.hole_ratio, 3),
67
+ 'detail_preservation': round(self.detail_preservation, 3),
68
+ 'hair_detail_score': round(self.hair_detail_score, 3),
69
+ 'texture_quality': round(self.texture_quality, 3),
70
+ 'overall_quality': round(self.overall_quality, 3),
71
+ 'processing_confidence': round(self.processing_confidence, 3),
72
+ 'breakdown': self.breakdown,
73
+ 'warnings': self.warnings
74
+ }
75
+
76
+ def get_summary(self) -> str:
77
+ """Get human-readable summary."""
78
+ status = "Excellent" if self.overall_quality > 0.9 else \
79
+ "Good" if self.overall_quality > 0.75 else \
80
+ "Fair" if self.overall_quality > 0.6 else "Poor"
81
+
82
+ return (f"Quality: {status} ({self.overall_quality:.1%})\n"
83
+ f"Edge: {self.edge_accuracy:.1%} | "
84
+ f"Temporal: {self.temporal_stability:.1%} | "
85
+ f"Detail: {self.detail_preservation:.1%}")
86
+
87
+
88
+ @dataclass
89
+ class QualityConfig:
90
+ """Configuration for quality analysis."""
91
+ enable_deep_analysis: bool = True
92
+ temporal_window: int = 5
93
+ edge_threshold: float = 0.1
94
+ min_confidence: float = 0.6
95
+ detect_artifacts: bool = True
96
+ compute_ssim: bool = True
97
+ compute_psnr: bool = True
98
+ save_reports: bool = True
99
+ report_dir: str = "LOGS/quality_reports"
100
+ warning_thresholds: Dict[str, float] = field(default_factory=lambda: {
101
+ 'edge_accuracy': 0.7,
102
+ 'temporal_stability': 0.75,
103
+ 'mask_accuracy': 0.8,
104
+ 'detail_preservation': 0.7
105
+ })
106
+
107
+
108
+ class QualityAnalyzer:
109
+ """Comprehensive quality analysis system."""
110
+
111
+ def __init__(self, config: Optional[QualityConfig] = None):
112
+ self.config = config or QualityConfig()
113
+ self.frame_buffer = deque(maxlen=self.config.temporal_window)
114
+ self.mask_buffer = deque(maxlen=self.config.temporal_window)
115
+ self.metrics_history = deque(maxlen=100)
116
+ self.frame_count = 0
117
+
118
+ # Initialize analyzers
119
+ self.edge_analyzer = EdgeQualityAnalyzer()
120
+ self.temporal_analyzer = TemporalQualityAnalyzer()
121
+ self.detail_analyzer = DetailPreservationAnalyzer()
122
+ self.artifact_detector = ArtifactDetector()
123
+
124
+ # Create report directory
125
+ if self.config.save_reports:
126
+ Path(self.config.report_dir).mkdir(parents=True, exist_ok=True)
127
+
128
+ def analyze_frame(self,
129
+ original_frame: np.ndarray,
130
+ processed_frame: np.ndarray,
131
+ mask: np.ndarray,
132
+ alpha: Optional[np.ndarray] = None) -> QualityMetrics:
133
+ """Analyze frame quality with REAL metrics."""
134
+ self.frame_count += 1
135
+ metrics = QualityMetrics()
136
+
137
+ # Add to buffers
138
+ self.frame_buffer.append(processed_frame)
139
+ self.mask_buffer.append(mask)
140
+
141
+ # 1. Edge Quality Analysis
142
+ edge_metrics = self.edge_analyzer.analyze(original_frame, mask, alpha)
143
+ metrics.edge_accuracy = edge_metrics['accuracy']
144
+ metrics.edge_smoothness = edge_metrics['smoothness']
145
+ metrics.edge_completeness = edge_metrics['completeness']
146
+
147
+ # 2. Temporal Quality (if we have history)
148
+ if len(self.mask_buffer) >= 2:
149
+ temporal_metrics = self.temporal_analyzer.analyze(
150
+ self.mask_buffer, self.frame_buffer
151
+ )
152
+ metrics.temporal_stability = temporal_metrics['stability']
153
+ metrics.temporal_consistency = temporal_metrics['consistency']
154
+ metrics.flicker_score = temporal_metrics['flicker']
155
+ else:
156
+ # First frame defaults
157
+ metrics.temporal_stability = 1.0
158
+ metrics.temporal_consistency = 1.0
159
+ metrics.flicker_score = 0.0
160
+
161
+ # 3. Mask Quality Analysis
162
+ mask_metrics = self._analyze_mask_quality(mask, alpha)
163
+ metrics.mask_coverage = mask_metrics['coverage']
164
+ metrics.mask_accuracy = mask_metrics['accuracy']
165
+ metrics.mask_confidence = mask_metrics['confidence']
166
+ metrics.hole_ratio = mask_metrics['hole_ratio']
167
+
168
+ # 4. Detail Preservation
169
+ detail_metrics = self.detail_analyzer.analyze(
170
+ original_frame, processed_frame, mask
171
+ )
172
+ metrics.detail_preservation = detail_metrics['overall']
173
+ metrics.hair_detail_score = detail_metrics['hair_detail']
174
+ metrics.texture_quality = detail_metrics['texture']
175
+
176
+ # 5. Artifact Detection
177
+ if self.config.detect_artifacts:
178
+ artifacts = self.artifact_detector.detect(processed_frame, mask)
179
+ if artifacts['found']:
180
+ for artifact in artifacts['types']:
181
+ metrics.warnings.append(f"Artifact detected: {artifact}")
182
+
183
+ # 6. Compute Overall Quality (weighted average)
184
+ metrics.overall_quality = self._compute_overall_quality(metrics)
185
+ metrics.processing_confidence = self._compute_confidence(metrics)
186
+
187
+ # 7. Generate warnings based on thresholds
188
+ self._generate_warnings(metrics)
189
+
190
+ # 8. Store in history
191
+ self.metrics_history.append(metrics)
192
+
193
+ # 9. Save report if configured
194
+ if self.config.save_reports and self.frame_count % 30 == 0:
195
+ self._save_report(metrics)
196
+
197
+ return metrics
198
+
199
+ def _analyze_mask_quality(self, mask: np.ndarray,
200
+ alpha: Optional[np.ndarray] = None) -> Dict[str, float]:
201
+ """Analyze mask quality metrics."""
202
+ h, w = mask.shape[:2]
203
+ total_pixels = h * w
204
+
205
+ # Coverage ratio
206
+ coverage = np.sum(mask > 0.5) / total_pixels
207
+
208
+ # Hole detection
209
+ mask_binary = (mask > 0.5).astype(np.uint8)
210
+
211
+ # Find contours
212
+ contours, _ = cv2.findContours(
213
+ mask_binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
214
+ )
215
+
216
+ # Find holes (internal contours)
217
+ hole_area = 0
218
+ if len(contours) > 0:
219
+ # Create filled mask
220
+ filled = np.zeros_like(mask_binary)
221
+ cv2.drawContours(filled, contours, -1, 1, -1)
222
+
223
+ # Holes are the difference
224
+ holes = filled - mask_binary
225
+ hole_area = np.sum(holes) / np.sum(filled) if np.sum(filled) > 0 else 0
226
+
227
+ # Accuracy (based on gradient consistency)
228
+ gradient_x = cv2.Sobel(mask, cv2.CV_64F, 1, 0, ksize=3)
229
+ gradient_y = cv2.Sobel(mask, cv2.CV_64F, 0, 1, ksize=3)
230
+ gradient_mag = np.sqrt(gradient_x**2 + gradient_y**2)
231
+
232
+ # Good masks have smooth gradients
233
+ gradient_smoothness = 1.0 - np.std(gradient_mag) / (np.mean(gradient_mag) + 1e-6)
234
+ accuracy = np.clip(gradient_smoothness, 0, 1)
235
+
236
+ # Confidence (alpha vs mask consistency if alpha provided)
237
+ if alpha is not None:
238
+ diff = np.abs(alpha - mask)
239
+ confidence = 1.0 - np.mean(diff)
240
+ else:
241
+ # Use mask value distribution as confidence
242
+ hist, _ = np.histogram(mask.flatten(), bins=10, range=(0, 1))
243
+ hist = hist / hist.sum()
244
+ # High confidence = values clustered near 0 or 1
245
+ confidence = (hist[0] + hist[-1]) / 2.0
246
+
247
+ return {
248
+ 'coverage': coverage,
249
+ 'accuracy': accuracy,
250
+ 'confidence': confidence,
251
+ 'hole_ratio': hole_area
252
+ }
253
+
254
+ def _compute_overall_quality(self, metrics: QualityMetrics) -> float:
255
+ """Compute weighted overall quality score."""
256
+ weights = {
257
+ 'edge': 0.25,
258
+ 'temporal': 0.25,
259
+ 'mask': 0.25,
260
+ 'detail': 0.25
261
+ }
262
+
263
+ # Component scores
264
+ edge_score = np.mean([
265
+ metrics.edge_accuracy,
266
+ metrics.edge_smoothness,
267
+ metrics.edge_completeness
268
+ ])
269
+
270
+ temporal_score = np.mean([
271
+ metrics.temporal_stability,
272
+ metrics.temporal_consistency,
273
+ 1.0 - metrics.flicker_score # Invert flicker
274
+ ])
275
+
276
+ mask_score = np.mean([
277
+ metrics.mask_accuracy,
278
+ metrics.mask_confidence,
279
+ 1.0 - metrics.hole_ratio # Invert hole ratio
280
+ ])
281
+
282
+ detail_score = np.mean([
283
+ metrics.detail_preservation,
284
+ metrics.hair_detail_score,
285
+ metrics.texture_quality
286
+ ])
287
+
288
+ # Weighted average
289
+ overall = (
290
+ weights['edge'] * edge_score +
291
+ weights['temporal'] * temporal_score +
292
+ weights['mask'] * mask_score +
293
+ weights['detail'] * detail_score
294
+ )
295
+
296
+ # Apply penalties for warnings
297
+ penalty = len(metrics.warnings) * 0.05
298
+ overall = max(0, overall - penalty)
299
+
300
+ return np.clip(overall, 0, 1)
301
+
302
+ def _compute_confidence(self, metrics: QualityMetrics) -> float:
303
+ """Compute processing confidence."""
304
+ # Factors that affect confidence
305
+ factors = []
306
+
307
+ # High edge accuracy increases confidence
308
+ factors.append(metrics.edge_accuracy)
309
+
310
+ # Good temporal stability increases confidence
311
+ factors.append(metrics.temporal_stability)
312
+
313
+ # Low hole ratio increases confidence
314
+ factors.append(1.0 - metrics.hole_ratio)
315
+
316
+ # Mask confidence directly affects overall confidence
317
+ factors.append(metrics.mask_confidence)
318
+
319
+ # No warnings increases confidence
320
+ warning_factor = 1.0 if len(metrics.warnings) == 0 else 0.8
321
+ factors.append(warning_factor)
322
+
323
+ return np.mean(factors)
324
+
325
+ def _generate_warnings(self, metrics: QualityMetrics):
326
+ """Generate warnings based on quality thresholds."""
327
+ for metric_name, threshold in self.config.warning_thresholds.items():
328
+ if hasattr(metrics, metric_name):
329
+ value = getattr(metrics, metric_name)
330
+ if value < threshold:
331
+ metrics.warnings.append(
332
+ f"Low {metric_name.replace('_', ' ')}: {value:.1%} < {threshold:.1%}"
333
+ )
334
+
335
+ def _save_report(self, metrics: QualityMetrics):
336
+ """Save quality report to file."""
337
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
338
+ report_path = Path(self.config.report_dir) / f"quality_report_{timestamp}.json"
339
+
340
+ report = {
341
+ 'timestamp': timestamp,
342
+ 'frame_count': self.frame_count,
343
+ 'metrics': metrics.to_dict(),
344
+ 'config': {
345
+ 'temporal_window': self.config.temporal_window,
346
+ 'edge_threshold': self.config.edge_threshold,
347
+ 'min_confidence': self.config.min_confidence
348
+ }
349
+ }
350
+
351
+ with open(report_path, 'w') as f:
352
+ json.dump(report, f, indent=2)
353
+
354
+ logger.info(f"Quality report saved to {report_path}")
355
+
356
+ def get_statistics(self) -> Dict[str, Any]:
357
+ """Get quality statistics over time."""
358
+ if not self.metrics_history:
359
+ return {}
360
+
361
+ # Compute statistics
362
+ all_metrics = list(self.metrics_history)
363
+
364
+ stats = {
365
+ 'average_quality': np.mean([m.overall_quality for m in all_metrics]),
366
+ 'min_quality': np.min([m.overall_quality for m in all_metrics]),
367
+ 'max_quality': np.max([m.overall_quality for m in all_metrics]),
368
+ 'std_quality': np.std([m.overall_quality for m in all_metrics]),
369
+ 'total_warnings': sum(len(m.warnings) for m in all_metrics),
370
+ 'frames_analyzed': len(all_metrics)
371
+ }
372
+
373
+ return stats
374
+
375
+
376
+ class EdgeQualityAnalyzer:
377
+ """Analyzes edge quality in masks."""
378
+
379
+ def analyze(self, image: np.ndarray, mask: np.ndarray,
380
+ alpha: Optional[np.ndarray] = None) -> Dict[str, float]:
381
+ """Analyze edge quality metrics."""
382
+ # Convert to grayscale if needed
383
+ if len(image.shape) == 3:
384
+ gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
385
+ else:
386
+ gray = image
387
+
388
+ # Detect edges in image
389
+ image_edges = cv2.Canny(gray, 50, 150) / 255.0
390
+
391
+ # Detect edges in mask
392
+ mask_uint8 = (mask * 255).astype(np.uint8)
393
+ mask_edges = cv2.Canny(mask_uint8, 50, 150) / 255.0
394
+
395
+ # Edge accuracy: how well mask edges align with image edges
396
+ overlap = np.logical_and(image_edges > 0, mask_edges > 0)
397
+ accuracy = np.sum(overlap) / (np.sum(mask_edges) + 1e-6)
398
+
399
+ # Edge smoothness: measure edge roughness
400
+ contours, _ = cv2.findContours(
401
+ mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
402
+ )
403
+
404
+ smoothness = 1.0
405
+ if len(contours) > 0:
406
+ # Approximate contours and measure approximation quality
407
+ for contour in contours:
408
+ perimeter = cv2.arcLength(contour, True)
409
+ if perimeter