# Document_Forgery_Detection / scripts / evaluate_full_testingset.py
# (Hugging Face upload metadata: JKrishnanandhaa, "Upload 8 files", revision 51fdac5 verified)
"""
Comprehensive Evaluation on Entire TestingSet
Evaluates the complete pipeline on all 30,000 samples from TestingSet
Calculates all metrics: Accuracy, Precision, Recall, F1, IoU, Dice, etc.
No visualizations - metrics only for speed
Usage:
python scripts/evaluate_full_testingset.py
"""
import sys
from pathlib import Path
import numpy as np
import torch
from tqdm import tqdm
import json
from datetime import datetime
sys.path.insert(0, str(Path(__file__).parent.parent))
from src.config import get_config
from src.models import get_model
from src.data import get_dataset
from src.features import get_mask_refiner, get_region_extractor
from src.training.classifier import ForgeryClassifier
from src.data.preprocessing import DocumentPreprocessor
from src.data.augmentation import DatasetAwareAugmentation
# Class mapping
CLASS_NAMES = {0: 'Copy-Move', 1: 'Splicing', 2: 'Generation'}
def calculate_metrics(pred_mask, gt_mask):
    """Compute pixel-level segmentation metrics between two binary masks.

    Args:
        pred_mask: numpy array of predicted labels; any numeric dtype,
            cast to bool (nonzero == forged pixel).
        gt_mask: numpy array of ground-truth labels, same shape.

    Returns:
        dict with float metrics ('iou', 'dice', 'precision', 'recall',
        'f1', 'accuracy') and the raw int confusion counts ('tp', 'fp',
        'fn', 'tn'). A small epsilon keeps every ratio finite when its
        denominator would otherwise be zero.
    """
    eps = 1e-8
    p = pred_mask.astype(bool)
    g = gt_mask.astype(bool)

    # Confusion-matrix counts over all pixels.
    tp = np.count_nonzero(p & g)
    fp = np.count_nonzero(p & ~g)
    fn = np.count_nonzero(~p & g)
    tn = np.count_nonzero(~p & ~g)

    # |pred ∪ gt| = tp + fp + fn; |pred| + |gt| = 2*tp + fp + fn.
    iou = tp / (tp + fp + fn + eps)
    dice = (2 * tp) / (2 * tp + fp + fn + eps)
    prec = tp / (tp + fp + eps)
    rec = tp / (tp + fn + eps)
    f1 = 2 * prec * rec / (prec + rec + eps)
    acc = (tp + tn) / (tp + tn + fp + fn + eps)

    return {
        'iou': float(iou),
        'dice': float(dice),
        'precision': float(prec),
        'recall': float(rec),
        'f1': float(f1),
        'accuracy': float(acc),
        'tp': int(tp),
        'fp': int(fp),
        'fn': int(fn),
        'tn': int(tn)
    }
def main():
    """Evaluate the full pipeline on the DocTamper TestingSet and save a JSON summary.

    Loads the localization model + classifier, runs inference over every
    sample of the 'val' split (the TestingSet), accumulates per-sample
    segmentation metrics and image-level detection statistics, prints a
    report, and writes outputs/evaluation/evaluation_summary.json.

    Fixes over the previous revision:
      * guards against ZeroDivisionError when no sample contains a forgery
        or the dataset is empty;
      * per-sample 'has_forgery' is stored as a plain bool (np.bool_ is not
        JSON-serializable if the detailed dump is ever re-enabled).
    """
    print("="*80)
    print("COMPREHENSIVE EVALUATION ON ENTIRE TESTINGSET")
    print("="*80)
    print("Dataset: DocTamper TestingSet (30,000 samples)")
    print("Mode: Metrics only (no visualizations)")
    print("="*80)

    config = get_config('config.yaml')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # ------------------------------------------------------------------
    # 1. Load models
    # ------------------------------------------------------------------
    print("\n1. Loading models...")

    # Localization model (pixel-level tamper segmentation).
    model = get_model(config).to(device)
    checkpoint = torch.load('outputs/checkpoints/best_doctamper.pth', map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    print(f" ✓ Localization model loaded (Dice: {checkpoint.get('best_metric', 0):.2%})")

    # Classifier (NOTE: loaded for pipeline parity but not used below —
    # this script reports localization metrics only).
    classifier = ForgeryClassifier(config)
    classifier.load('outputs/classifier')
    print(f" ✓ Classifier loaded")

    # Pipeline components (preprocessor/augmentation/region_extractor are
    # instantiated for parity with the full pipeline; only mask_refiner is
    # exercised in this metrics-only loop).
    preprocessor = DocumentPreprocessor(config, 'doctamper')
    augmentation = DatasetAwareAugmentation(config, 'doctamper', is_training=False)
    mask_refiner = get_mask_refiner(config)
    region_extractor = get_region_extractor(config)

    # ------------------------------------------------------------------
    # 2. Load dataset
    # ------------------------------------------------------------------
    print("\n2. Loading TestingSet...")
    dataset = get_dataset(config, 'doctamper', split='val')  # val = TestingSet
    total_samples = len(dataset)
    print(f" ✓ Loaded {total_samples} samples")

    # Per-sample metric dicts and image-level detection counters.
    all_metrics = []
    detection_stats = {
        'total': 0,
        'has_forgery': 0,
        'detected': 0,
        'missed': 0,
        'false_positives': 0,
        'true_negatives': 0
    }

    # ------------------------------------------------------------------
    # 3. Inference + metric accumulation
    # ------------------------------------------------------------------
    print("\n3. Running evaluation...")
    print("="*80)
    for idx in tqdm(range(total_samples), desc="Evaluating"):
        try:
            image_tensor, mask_tensor, metadata = dataset[idx]

            # Ground truth: channel 0 of the mask tensor, binarized at 0.5.
            gt_mask = mask_tensor.numpy()[0]
            gt_mask_binary = (gt_mask > 0.5).astype(np.uint8)
            has_forgery = bool(gt_mask_binary.sum() > 0)  # plain bool for JSON

            # Localization forward pass.
            with torch.no_grad():
                image_batch = image_tensor.unsqueeze(0).to(device)
                logits, _ = model(image_batch)
                prob_map = torch.sigmoid(logits).cpu().numpy()[0, 0]

            # Threshold and refine the predicted mask.
            binary_mask = (prob_map > 0.5).astype(np.uint8)
            refined_mask = mask_refiner.refine(binary_mask)

            metrics = calculate_metrics(refined_mask, gt_mask_binary)
            metrics['sample_idx'] = idx
            metrics['has_forgery'] = has_forgery
            metrics['prob_max'] = float(prob_map.max())

            # Image-level detection bookkeeping: "detected" means any
            # predicted forged pixel survives refinement.
            detected = refined_mask.sum() > 0
            detection_stats['total'] += 1
            if has_forgery:
                detection_stats['has_forgery'] += 1
                if detected:
                    detection_stats['detected'] += 1
                else:
                    detection_stats['missed'] += 1
            else:
                if detected:
                    detection_stats['false_positives'] += 1
                else:
                    detection_stats['true_negatives'] += 1

            all_metrics.append(metrics)
        except Exception as e:
            # Best-effort evaluation: report and skip a broken sample
            # rather than aborting a 30k-sample run.
            print(f"\nError at sample {idx}: {str(e)[:100]}")
            continue

    # ------------------------------------------------------------------
    # 4. Report results
    # ------------------------------------------------------------------
    print("\n" + "="*80)
    print("RESULTS")
    print("="*80)

    print("\n📊 DETECTION STATISTICS:")
    print("-"*80)
    print(f"Total samples: {detection_stats['total']}")
    print(f"Samples with forgery: {detection_stats['has_forgery']}")
    print(f"Samples without forgery: {detection_stats['total'] - detection_stats['has_forgery']}")
    print()
    print(f"✅ Correctly detected: {detection_stats['detected']}")
    print(f"❌ Missed detections: {detection_stats['missed']}")
    print(f"⚠️ False positives: {detection_stats['false_positives']}")
    print(f"✓ True negatives: {detection_stats['true_negatives']}")
    print()

    if detection_stats['has_forgery'] > 0:
        detection_rate = detection_stats['detected'] / detection_stats['has_forgery']
        miss_rate = detection_stats['missed'] / detection_stats['has_forgery']
        print(f"Detection Rate (Recall): {detection_rate:.2%} ⬆️ Higher is better")
        print(f"Miss Rate: {miss_rate:.2%} ⬇️ Lower is better")
    if detection_stats['detected'] + detection_stats['false_positives'] > 0:
        precision = detection_stats['detected'] / (detection_stats['detected'] + detection_stats['false_positives'])
        print(f"Precision: {precision:.2%} ⬆️ Higher is better")

    # Guard against an empty dataset (previously an unguarded division).
    if detection_stats['total'] > 0:
        overall_accuracy = (detection_stats['detected'] + detection_stats['true_negatives']) / detection_stats['total']
    else:
        overall_accuracy = 0
    print(f"Overall Accuracy: {overall_accuracy:.2%} ⬆️ Higher is better")

    # Segmentation metrics are averaged only over samples that actually
    # contain a forgery (IoU/Dice are degenerate on all-negative masks).
    forgery_metrics = [m for m in all_metrics if m['has_forgery']]
    if forgery_metrics:
        print("\n📈 SEGMENTATION METRICS (on samples with forgery):")
        print("-"*80)
        avg_iou = np.mean([m['iou'] for m in forgery_metrics])
        avg_dice = np.mean([m['dice'] for m in forgery_metrics])
        avg_precision = np.mean([m['precision'] for m in forgery_metrics])
        avg_recall = np.mean([m['recall'] for m in forgery_metrics])
        avg_f1 = np.mean([m['f1'] for m in forgery_metrics])
        avg_accuracy = np.mean([m['accuracy'] for m in forgery_metrics])
        print(f"IoU (Intersection over Union): {avg_iou:.4f} ⬆️ Higher is better (0-1)")
        print(f"Dice Coefficient: {avg_dice:.4f} ⬆️ Higher is better (0-1)")
        print(f"Pixel Precision: {avg_precision:.4f} ⬆️ Higher is better (0-1)")
        print(f"Pixel Recall: {avg_recall:.4f} ⬆️ Higher is better (0-1)")
        print(f"Pixel F1-Score: {avg_f1:.4f} ⬆️ Higher is better (0-1)")
        print(f"Pixel Accuracy: {avg_accuracy:.4f} ⬆️ Higher is better (0-1)")
        avg_prob = np.mean([m['prob_max'] for m in forgery_metrics])
        print(f"\nAverage Max Probability: {avg_prob:.4f}")

    # ------------------------------------------------------------------
    # 5. Persist summary
    # ------------------------------------------------------------------
    print("\n" + "="*80)
    print("SAVING RESULTS")
    print("="*80)
    output_dir = Path('outputs/evaluation')
    output_dir.mkdir(parents=True, exist_ok=True)

    summary = {
        'timestamp': datetime.now().isoformat(),
        'total_samples': detection_stats['total'],
        'detection_statistics': detection_stats,
        'detection_rate': detection_stats['detected'] / detection_stats['has_forgery'] if detection_stats['has_forgery'] > 0 else 0,
        'precision': detection_stats['detected'] / (detection_stats['detected'] + detection_stats['false_positives']) if (detection_stats['detected'] + detection_stats['false_positives']) > 0 else 0,
        'overall_accuracy': overall_accuracy,
        'segmentation_metrics': {
            'iou': float(avg_iou) if forgery_metrics else 0,
            'dice': float(avg_dice) if forgery_metrics else 0,
            'precision': float(avg_precision) if forgery_metrics else 0,
            'recall': float(avg_recall) if forgery_metrics else 0,
            'f1': float(avg_f1) if forgery_metrics else 0,
            'accuracy': float(avg_accuracy) if forgery_metrics else 0
        }
    }
    summary_path = output_dir / 'evaluation_summary.json'
    with open(summary_path, 'w') as f:
        json.dump(summary, f, indent=2)
    print(f"✓ Summary saved to: {summary_path}")

    print("\n" + "="*80)
    print("✅ EVALUATION COMPLETE!")
    print("="*80)
    print(f"\nKey Metrics Summary:")
    # Guard: previously raised ZeroDivisionError when no sample had a forgery.
    if detection_stats['has_forgery'] > 0:
        print(f" Detection Rate: {detection_stats['detected'] / detection_stats['has_forgery']:.2%}")
    else:
        print(" Detection Rate: N/A")
    print(f" Overall Accuracy: {overall_accuracy:.2%}")
    print(f" Dice Score: {avg_dice:.4f}" if forgery_metrics else " Dice Score: N/A")
    print(f" IoU: {avg_iou:.4f}" if forgery_metrics else " IoU: N/A")
    print("="*80 + "\n")


if __name__ == '__main__':
    main()