# NOTE: the lines "Spaces:" / "Sleeping" / "Sleeping" that appeared here were
# Hugging Face Spaces page chrome captured by the scraper, not part of this script.
"""
Comprehensive Evaluation on Entire TestingSet
Evaluates the complete pipeline on all 30,000 samples from TestingSet
Calculates all metrics: Accuracy, Precision, Recall, F1, IoU, Dice, etc.
No visualizations - metrics only for speed
Usage:
    python scripts/evaluate_full_testingset.py
"""
import sys
from pathlib import Path

import numpy as np
import torch
from tqdm import tqdm
import json
from datetime import datetime

# Make the repository root importable so the `src` package resolves when this
# script is run from scripts/.
sys.path.insert(0, str(Path(__file__).parent.parent))

from src.config import get_config
from src.models import get_model
from src.data import get_dataset
from src.features import get_mask_refiner, get_region_extractor
from src.training.classifier import ForgeryClassifier
from src.data.preprocessing import DocumentPreprocessor
from src.data.augmentation import DatasetAwareAugmentation

# Class mapping
# NOTE(review): forgery-type labels for the classifier output; not referenced
# anywhere in this file's visible code — presumably kept for parity with the
# rest of the pipeline. Confirm before removing.
CLASS_NAMES = {0: 'Copy-Move', 1: 'Splicing', 2: 'Generation'}
def calculate_metrics(pred_mask, gt_mask):
    """Compute pixel-level segmentation metrics for a binary prediction.

    Args:
        pred_mask: array-like of 0/1 values, the predicted forgery mask.
        gt_mask: array-like of 0/1 values, the ground-truth mask
            (same shape as ``pred_mask``).

    Returns:
        dict with float metrics (``iou``, ``dice``, ``precision``,
        ``recall``, ``f1``, ``accuracy``) and int confusion counts
        (``tp``, ``fp``, ``fn``, ``tn``). A small epsilon keeps every
        denominator non-zero, so degenerate (all-empty) masks yield 0
        rather than raising.
    """
    eps = 1e-8
    prediction = pred_mask.astype(bool)
    truth = gt_mask.astype(bool)

    # Confusion-matrix counts over pixels.
    true_pos = (prediction & truth).sum()
    false_pos = (prediction & ~truth).sum()
    false_neg = (~prediction & truth).sum()
    true_neg = (~prediction & ~truth).sum()

    # |A ∪ B| == tp + fp + fn — identical to (prediction | truth).sum().
    union = true_pos + false_pos + false_neg

    iou = true_pos / (union + eps)
    dice = (2 * true_pos) / (prediction.sum() + truth.sum() + eps)
    precision = true_pos / (true_pos + false_pos + eps)
    recall = true_pos / (true_pos + false_neg + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)
    accuracy = (true_pos + true_neg) / (true_pos + true_neg + false_pos + false_neg + eps)

    return {
        'iou': float(iou),
        'dice': float(dice),
        'precision': float(precision),
        'recall': float(recall),
        'f1': float(f1),
        'accuracy': float(accuracy),
        'tp': int(true_pos),
        'fp': int(false_pos),
        'fn': int(false_neg),
        'tn': int(true_neg)
    }
def main():
    """Run the full-pipeline evaluation over every TestingSet sample.

    Loads the localization model, the classifier, and the preprocessing
    components, iterates the DocTamper validation split, accumulates
    per-sample segmentation metrics plus image-level detection counts,
    prints a report, and writes a JSON summary to outputs/evaluation/.

    Fixes vs. the original:
      * guards the final "Detection Rate" print against a split with zero
        forged samples (the original divided unconditionally at the end,
        while the identical division earlier was guarded);
      * returns early when no sample was evaluated, so the overall-accuracy
        division cannot hit a zero denominator.
    """
    print("="*80)
    print("COMPREHENSIVE EVALUATION ON ENTIRE TESTINGSET")
    print("="*80)
    print("Dataset: DocTamper TestingSet (30,000 samples)")
    print("Mode: Metrics only (no visualizations)")
    print("="*80)

    config = get_config('config.yaml')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load models
    print("\n1. Loading models...")

    # Localization model (pixel-level forgery segmentation)
    model = get_model(config).to(device)
    checkpoint = torch.load('outputs/checkpoints/best_doctamper.pth', map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    print(f" ✓ Localization model loaded (Dice: {checkpoint.get('best_metric', 0):.2%})")

    # Classifier — NOTE(review): loaded but its predictions are never used
    # below; presumably kept so a load failure surfaces here. Confirm whether
    # per-class metrics were intended.
    classifier = ForgeryClassifier(config)
    classifier.load('outputs/classifier')
    print(f" ✓ Classifier loaded")

    # Components — NOTE(review): preprocessor/augmentation/region_extractor
    # are instantiated but unused in this metrics-only path; only
    # mask_refiner participates below.
    preprocessor = DocumentPreprocessor(config, 'doctamper')
    augmentation = DatasetAwareAugmentation(config, 'doctamper', is_training=False)
    mask_refiner = get_mask_refiner(config)
    region_extractor = get_region_extractor(config)

    # Load dataset
    print("\n2. Loading TestingSet...")
    dataset = get_dataset(config, 'doctamper', split='val')  # val = TestingSet
    total_samples = len(dataset)
    print(f" ✓ Loaded {total_samples} samples")

    # Initialize metrics storage
    all_metrics = []
    detection_stats = {
        'total': 0,
        'has_forgery': 0,
        'detected': 0,
        'missed': 0,
        'false_positives': 0,
        'true_negatives': 0
    }

    print("\n3. Running evaluation...")
    print("="*80)

    # Process all samples
    for idx in tqdm(range(total_samples), desc="Evaluating"):
        try:
            # Get sample from dataset
            image_tensor, mask_tensor, metadata = dataset[idx]

            # Ground truth: channel 0, binarized at 0.5
            gt_mask = mask_tensor.numpy()[0]
            gt_mask_binary = (gt_mask > 0.5).astype(np.uint8)
            has_forgery = gt_mask_binary.sum() > 0

            # Run localization
            with torch.no_grad():
                image_batch = image_tensor.unsqueeze(0).to(device)
                logits, _ = model(image_batch)
                prob_map = torch.sigmoid(logits).cpu().numpy()[0, 0]

            # Generate mask: threshold then refine
            binary_mask = (prob_map > 0.5).astype(np.uint8)
            refined_mask = mask_refiner.refine(binary_mask)

            # Calculate metrics
            metrics = calculate_metrics(refined_mask, gt_mask_binary)
            metrics['sample_idx'] = idx
            metrics['has_forgery'] = has_forgery
            metrics['prob_max'] = float(prob_map.max())

            # Detection statistics (image-level: any predicted pixel = detected)
            detected = refined_mask.sum() > 0
            detection_stats['total'] += 1
            if has_forgery:
                detection_stats['has_forgery'] += 1
                if detected:
                    detection_stats['detected'] += 1
                else:
                    detection_stats['missed'] += 1
            else:
                if detected:
                    detection_stats['false_positives'] += 1
                else:
                    detection_stats['true_negatives'] += 1

            all_metrics.append(metrics)
        except Exception as e:
            # Best-effort: report and continue so one bad sample does not
            # abort a 30k-sample run.
            print(f"\nError at sample {idx}: {str(e)[:100]}")
            continue

    # FIX: every division below uses these counts; bail out if nothing ran.
    if detection_stats['total'] == 0:
        print("\nNo samples were evaluated - aborting.")
        return

    # Calculate overall statistics
    print("\n" + "="*80)
    print("RESULTS")
    print("="*80)

    # Detection statistics
    print("\n📊 DETECTION STATISTICS:")
    print("-"*80)
    print(f"Total samples: {detection_stats['total']}")
    print(f"Samples with forgery: {detection_stats['has_forgery']}")
    print(f"Samples without forgery: {detection_stats['total'] - detection_stats['has_forgery']}")
    print()
    print(f"✅ Correctly detected: {detection_stats['detected']}")
    print(f"❌ Missed detections: {detection_stats['missed']}")
    print(f"⚠️ False positives: {detection_stats['false_positives']}")
    print(f"✓ True negatives: {detection_stats['true_negatives']}")
    print()

    # Detection rates
    if detection_stats['has_forgery'] > 0:
        detection_rate = detection_stats['detected'] / detection_stats['has_forgery']
        miss_rate = detection_stats['missed'] / detection_stats['has_forgery']
        print(f"Detection Rate (Recall): {detection_rate:.2%} ⬆️ Higher is better")
        print(f"Miss Rate: {miss_rate:.2%} ⬇️ Lower is better")
    if detection_stats['detected'] + detection_stats['false_positives'] > 0:
        precision = detection_stats['detected'] / (detection_stats['detected'] + detection_stats['false_positives'])
        print(f"Precision: {precision:.2%} ⬆️ Higher is better")

    overall_accuracy = (detection_stats['detected'] + detection_stats['true_negatives']) / detection_stats['total']
    print(f"Overall Accuracy: {overall_accuracy:.2%} ⬆️ Higher is better")

    # Segmentation metrics (only for samples with forgery)
    forgery_metrics = [m for m in all_metrics if m['has_forgery']]
    if forgery_metrics:
        print("\n📈 SEGMENTATION METRICS (on samples with forgery):")
        print("-"*80)
        avg_iou = np.mean([m['iou'] for m in forgery_metrics])
        avg_dice = np.mean([m['dice'] for m in forgery_metrics])
        avg_precision = np.mean([m['precision'] for m in forgery_metrics])
        avg_recall = np.mean([m['recall'] for m in forgery_metrics])
        avg_f1 = np.mean([m['f1'] for m in forgery_metrics])
        avg_accuracy = np.mean([m['accuracy'] for m in forgery_metrics])
        print(f"IoU (Intersection over Union): {avg_iou:.4f} ⬆️ Higher is better (0-1)")
        print(f"Dice Coefficient: {avg_dice:.4f} ⬆️ Higher is better (0-1)")
        print(f"Pixel Precision: {avg_precision:.4f} ⬆️ Higher is better (0-1)")
        print(f"Pixel Recall: {avg_recall:.4f} ⬆️ Higher is better (0-1)")
        print(f"Pixel F1-Score: {avg_f1:.4f} ⬆️ Higher is better (0-1)")
        print(f"Pixel Accuracy: {avg_accuracy:.4f} ⬆️ Higher is better (0-1)")

        # Probability statistics
        avg_prob = np.mean([m['prob_max'] for m in forgery_metrics])
        print(f"\nAverage Max Probability: {avg_prob:.4f}")

    # Save results
    print("\n" + "="*80)
    print("SAVING RESULTS")
    print("="*80)
    output_dir = Path('outputs/evaluation')
    output_dir.mkdir(parents=True, exist_ok=True)

    # Summary (the `if forgery_metrics` guards keep the avg_* names from
    # being evaluated when they were never bound)
    summary = {
        'timestamp': datetime.now().isoformat(),
        'total_samples': detection_stats['total'],
        'detection_statistics': detection_stats,
        'detection_rate': detection_stats['detected'] / detection_stats['has_forgery'] if detection_stats['has_forgery'] > 0 else 0,
        'precision': detection_stats['detected'] / (detection_stats['detected'] + detection_stats['false_positives']) if (detection_stats['detected'] + detection_stats['false_positives']) > 0 else 0,
        'overall_accuracy': overall_accuracy,
        'segmentation_metrics': {
            'iou': float(avg_iou) if forgery_metrics else 0,
            'dice': float(avg_dice) if forgery_metrics else 0,
            'precision': float(avg_precision) if forgery_metrics else 0,
            'recall': float(avg_recall) if forgery_metrics else 0,
            'f1': float(avg_f1) if forgery_metrics else 0,
            'accuracy': float(avg_accuracy) if forgery_metrics else 0
        }
    }
    summary_path = output_dir / 'evaluation_summary.json'
    with open(summary_path, 'w') as f:
        json.dump(summary, f, indent=2)
    print(f"✓ Summary saved to: {summary_path}")

    print("\n" + "="*80)
    print("✅ EVALUATION COMPLETE!")
    print("="*80)
    print(f"\nKey Metrics Summary:")
    # FIX: the original divided by has_forgery unconditionally here, which
    # raises ZeroDivisionError on a split with no forged samples.
    if detection_stats['has_forgery'] > 0:
        print(f" Detection Rate: {detection_stats['detected'] / detection_stats['has_forgery']:.2%}")
    print(f" Overall Accuracy: {overall_accuracy:.2%}")
    print(f" Dice Score: {avg_dice:.4f}" if forgery_metrics else " Dice Score: N/A")
    print(f" IoU: {avg_iou:.4f}" if forgery_metrics else " IoU: N/A")
    print("="*80 + "\n")
# Script entry point: run the evaluation only when executed directly,
# not when imported as a module.
if __name__ == '__main__':
    main()