JKrishnanandhaa committed on
Commit
3870592
·
verified ·
1 Parent(s): f50c0c4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +143 -530
app.py CHANGED
@@ -1,7 +1,6 @@
1
  """
2
- Document Forgery Detection - Gradio Interface for Hugging Face Spaces
3
-
4
- This app provides a web interface for detecting and classifying document forgeries.
5
  """
6
 
7
  import gradio as gr
@@ -9,13 +8,14 @@ import torch
9
  import cv2
10
  import numpy as np
11
  from PIL import Image
12
- import json
13
  from pathlib import Path
14
  import sys
15
- from typing import Dict, List, Tuple
16
- import plotly.graph_objects as go
17
 
18
- # Add src to path
 
 
19
  sys.path.insert(0, str(Path(__file__).parent))
20
 
21
  from src.models import get_model
@@ -26,568 +26,181 @@ from src.features.region_extraction import get_mask_refiner, get_region_extracto
26
  from src.features.feature_extraction import get_feature_extractor
27
  from src.training.classifier import ForgeryClassifier
28
 
29
# Forgery class index -> human-readable label.
CLASS_NAMES = dict(enumerate(('Copy-Move', 'Splicing', 'Text Substitution')))

# Forgery class index -> color triple (hex equivalents noted per entry).
CLASS_COLORS = dict(enumerate((
    (217, 83, 79),   # #d9534f - Muted red
    (92, 184, 92),   # #5cb85c - Muted green
    (65, 105, 225),  # #4169E1 - Royal blue
)))

# Model performance figures displayed in the UI (fractions in [0, 1]).
MODEL_METRICS = dict(
    segmentation=dict(
        dice=0.6212,
        iou=0.4506,
        precision=0.7077,
        recall=0.5536,
    ),
    classification=dict(
        overall_accuracy=0.8897,
        per_class=dict(
            copy_move=0.92,
            splicing=0.85,
            generation=0.90,
        ),
    ),
)
54
-
55
-
56
def create_gauge_chart(value: float, title: str, max_value: float = 1.0) -> go.Figure:
    """Build a small Plotly gauge that displays ``value`` as a percentage.

    Args:
        value: Fraction to display; it is multiplied by 100 for the dial.
        title: Caption rendered above the gauge.
        max_value: Accepted for interface compatibility; not used by the body.

    Returns:
        A ``go.Figure`` holding a single indicator on a transparent background.
    """
    # (range, tint) pairs for the faint red / amber / green background bands.
    bands = (
        ([0, 50], 'rgba(217, 83, 79, 0.1)'),
        ([50, 75], 'rgba(240, 173, 78, 0.1)'),
        ([75, 100], 'rgba(92, 184, 92, 0.1)'),
    )
    dial = {
        'axis': {'range': [0, 100], 'tickwidth': 1},
        'bar': {'color': '#4169E1', 'thickness': 0.7},
        'bgcolor': 'rgba(0,0,0,0)',
        'borderwidth': 0,
        'steps': [{'range': span, 'color': tint} for span, tint in bands],
    }
    indicator = go.Indicator(
        mode="gauge+number",
        value=value * 100,
        domain={'x': [0, 1], 'y': [0, 1]},
        title={'text': title, 'font': {'size': 14}},
        number={'suffix': '%', 'font': {'size': 24}},
        gauge=dial,
    )

    figure = go.Figure(indicator)
    figure.update_layout(
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        height=200,
        margin={'l': 20, 'r': 20, 't': 40, 'b': 20},
    )
    return figure
85
-
86
-
87
def create_detection_metrics_gauge(avg_confidence: float, iou: float, precision: float, recall: float, num_detections: int) -> go.Figure:
    """Draw four concentric rings, one per metric, each filled in proportion to its value.

    Args:
        avg_confidence: Mean detection confidence as a fraction; shown as 0
            when ``num_detections`` is not positive.
        iou: IoU as a fraction.
        precision: Precision as a fraction.
        recall: Recall as a fraction.
        num_detections: Number of reported detections.

    Returns:
        A ``go.Figure`` with two ``Barpolar`` traces per metric (track + arc).
    """
    confidence_pct = avg_confidence * 100 if num_detections > 0 else 0

    # (legend name, percentage, arc color, inner radius of the ring)
    rings = (
        ('Confidence', confidence_pct, '#4169E1', 80),
        ('Precision', precision * 100, '#5cb85c', 60),
        ('Recall', recall * 100, '#f0ad4e', 40),
        ('IoU', iou * 100, '#d9534f', 20),
    )

    figure = go.Figure()
    for ring_name, percent, ring_color, inner_radius in rings:
        # Faint gray full-circle track behind the colored arc.
        track = go.Barpolar(
            r=[15],
            theta=[180],
            width=[360],
            base=inner_radius,
            marker_color='rgba(128,128,128,0.1)',
            hoverinfo='none',
            showlegend=False,
        )
        figure.add_trace(track)

        # The metric itself: 100% corresponds to a 360-degree sweep.
        sweep = percent * 3.6
        arc = go.Barpolar(
            r=[15],
            theta=[sweep / 2],
            width=[sweep],
            base=inner_radius,
            name=f"{ring_name}: {percent:.1f}%",
            marker_color=ring_color,
            marker_line_width=0,
            hoverinfo='name',
        )
        figure.add_trace(arc)

    angular_axis = dict(
        rotation=90,            # start at 12 o'clock
        direction='clockwise',
        gridcolor='rgba(128,128,128,0.2)',
        tickmode='array',
        tickvals=[0, 90, 180, 270],
        ticktext=['0%', '25%', '50%', '75%'],
        showticklabels=True,
        tickfont=dict(size=12, color='#888'),
    )
    legend_style = dict(
        orientation="v",
        yanchor="middle",
        y=0.5,
        xanchor="left",
        x=1.1,
        font=dict(size=14, color='white'),
        itemwidth=30,
    )
    figure.update_layout(
        polar=dict(
            hole=0.1,
            radialaxis=dict(visible=False, range=[0, 100]),
            angularaxis=angular_axis,
            bgcolor='rgba(0,0,0,0)',
        ),
        showlegend=True,
        legend=legend_style,
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        height=450,
        margin=dict(l=60, r=180, t=40, b=40),
    )
    return figure
159
-
160
-
161
class ForgeryDetector:
    """Main forgery detection pipeline: localize, classify, visualize, report."""

    # Feature-vector width fed to the classifier; shorter vectors are
    # zero-padded and longer ones truncated to this length.
    EXPECTED_FEATURES = 526
    # A region is reported only when its classifier confidence exceeds this.
    MIN_CONFIDENCE = 0.6

    def __init__(self):
        """Load config, segmentation checkpoint, classifier and helper components."""
        print("Loading models...")

        # Load config
        self.config = get_config('config.yaml')
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Load segmentation model
        self.model = get_model(self.config).to(self.device)
        checkpoint = torch.load('models/best_doctamper.pth', map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()

        # Load classifier
        self.classifier = ForgeryClassifier(self.config)
        self.classifier.load('models/classifier')

        # Initialize components
        self.preprocessor = DocumentPreprocessor(self.config, 'doctamper')
        self.augmentation = DatasetAwareAugmentation(self.config, 'doctamper', is_training=False)
        self.mask_refiner = get_mask_refiner(self.config)
        self.region_extractor = get_region_extractor(self.config)
        self.feature_extractor = get_feature_extractor(self.config, is_text_document=True)

        print("✓ Models loaded successfully!")

    def detect(self, image):
        """
        Detect forgeries in a document image or PDF.

        Args:
            image: A file path (image file, or ``.pdf`` of which only the
                first page is rendered), a PIL image, or a NumPy array.

        Returns:
            overlay: Copy of the input with boxes and labels drawn on it
            metrics_gauge: Figure from ``create_detection_metrics_gauge``
            results_html: Detection results as HTML
        """
        # Handle file path input (from gr.File / gr.Image with type="filepath")
        if isinstance(image, str):
            image = self._load_image(image)

        # Convert PIL to numpy
        if isinstance(image, Image.Image):
            image = np.array(image)

        # Convert to 3-channel RGB
        if len(image.shape) == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

        original_image = image.copy()

        # Preprocess, then apply the inference-time augmentation pipeline
        preprocessed, _ = self.preprocessor(image, None)
        augmented = self.augmentation(preprocessed, None)
        image_tensor = augmented['image'].unsqueeze(0).to(self.device)

        # Run localization
        with torch.no_grad():
            logits, decoder_features = self.model(image_tensor)
            prob_map = torch.sigmoid(logits).cpu().numpy()[0, 0]

        # Resize probability map to match original image size to avoid index mismatch errors
        prob_map_resized = cv2.resize(
            prob_map,
            (original_image.shape[1], original_image.shape[0]),
            interpolation=cv2.INTER_LINEAR
        )

        # Refine mask and extract candidate regions
        refined_mask = self.mask_refiner.refine(prob_map_resized, original_size=original_image.shape[:2])
        regions = self.region_extractor.extract(refined_mask, prob_map_resized, original_image)

        # The decoder features do not change per region, so move them to the
        # CPU once instead of once per region.
        cpu_features = [f.cpu() for f in decoder_features]

        # Classify regions
        results = []
        for region in regions:
            features = self.feature_extractor.extract(
                preprocessed,
                region['region_mask'],
                cpu_features
            )

            # Reshape features to 2D array
            if features.ndim == 1:
                features = features.reshape(1, -1)

            # Pad/truncate features to match classifier
            current_features = features.shape[1]
            if current_features < self.EXPECTED_FEATURES:
                padding = np.zeros((features.shape[0], self.EXPECTED_FEATURES - current_features))
                features = np.hstack([features, padding])
            elif current_features > self.EXPECTED_FEATURES:
                features = features[:, :self.EXPECTED_FEATURES]

            predictions, confidences = self.classifier.predict(features)
            forgery_type = int(predictions[0])
            confidence = float(confidences[0])

            if confidence > self.MIN_CONFIDENCE:
                results.append({
                    'region_id': region['region_id'],
                    'bounding_box': region['bounding_box'],
                    'forgery_type': CLASS_NAMES[forgery_type],
                    'confidence': confidence
                })

        overlay = self._create_overlay(original_image, results)

        num_detections = len(results)
        avg_confidence = sum(r['confidence'] for r in results) / num_detections if num_detections > 0 else 0

        # NOTE: there is no ground truth at inference time. These figures
        # compare the refined mask against the >0.7 probability mask, so they
        # describe internal agreement of the prediction, not accuracy.
        if num_detections > 0:
            refined_positive = refined_mask > 0
            high_conf_positive_mask = prob_map_resized > 0.7
            predicted_positive = np.sum(refined_positive)
            high_conf_positive = np.sum(high_conf_positive_mask)

            intersection = np.sum(refined_positive & high_conf_positive_mask)
            union = np.sum(refined_positive | high_conf_positive_mask)

            iou = intersection / union if union > 0 else 0
            precision = intersection / predicted_positive if predicted_positive > 0 else 0
            recall = intersection / high_conf_positive if high_conf_positive > 0 else 0
        else:
            # No detections - use zeros
            iou = 0
            precision = 0
            recall = 0

        metrics_gauge = create_detection_metrics_gauge(avg_confidence, iou, precision, recall, num_detections)
        results_html = self._create_html_report(results)

        return overlay, metrics_gauge, results_html

    @staticmethod
    def _load_image(path):
        """Read ``path`` into a NumPy array; PDFs are rasterized (first page, 2x zoom)."""
        if path.lower().endswith('.pdf'):
            import fitz  # PyMuPDF

            pdf_document = fitz.open(path)
            try:
                page = pdf_document[0]
                pix = page.get_pixmap(matrix=fitz.Matrix(2, 2))
                image = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, pix.n)
                if pix.n == 4:
                    image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
            finally:
                # Close the document even when rendering fails.
                pdf_document.close()
            return image

        # The context manager releases the file handle once the pixels are copied out.
        with Image.open(path) as pil_image:
            return np.array(pil_image)

    @staticmethod
    def _class_color(forgery_type):
        """Return the ``CLASS_COLORS`` entry for a ``CLASS_NAMES`` label."""
        forgery_id = next(k for k, v in CLASS_NAMES.items() if v == forgery_type)
        return CLASS_COLORS[forgery_id]

    def _create_overlay(self, image, results):
        """Return a copy of ``image`` with one labeled rectangle per result."""
        overlay = image.copy()
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.5
        thickness = 1

        for result in results:
            x, y, w, h = result['bounding_box']
            forgery_type = result['forgery_type']
            color = self._class_color(forgery_type)

            # Draw rectangle
            cv2.rectangle(overlay, (x, y), (x + w, y + h), color, 2)

            # Draw label on a filled background strip
            label = f"{forgery_type}: {result['confidence']:.1%}"
            (label_w, label_h), _baseline = cv2.getTextSize(label, font, font_scale, thickness)

            label_top = y - label_h - 8
            if label_top < 0:
                # No room above the box (it touches the top edge): put the
                # strip just inside the box so the label stays visible.
                label_top = y
            label_bottom = label_top + label_h + 8

            cv2.rectangle(overlay, (x, label_top), (x + label_w + 4, label_bottom), color, -1)
            cv2.putText(overlay, label, (x + 2, label_bottom - 4), font, font_scale, (255, 255, 255), thickness)

        return overlay

    def _create_html_report(self, results):
        """Render the detections as an HTML summary card."""
        num_detections = len(results)

        if num_detections == 0:
            return """
            <div style='padding:12px; border:1px solid #5cb85c; border-radius:8px;'>
            ✓ <b>No forgery detected.</b><br>
            The document appears to be authentic.
            </div>
            """

        avg_confidence = sum(r['confidence'] for r in results) / num_detections

        html = f"""
        <div style='padding:12px; border:1px solid #d9534f; border-radius:8px;'>
        <b>⚠️ Forgery Detected</b><br><br>

        <b>Summary:</b><br>
        • Regions detected: {num_detections}<br>
        • Average confidence: {avg_confidence*100:.1f}%<br><br>

        <b>Detections:</b><br>
        """

        for i, result in enumerate(results, 1):
            forgery_type = result['forgery_type']
            confidence = result['confidence']
            bbox = result['bounding_box']

            color_rgb = self._class_color(forgery_type)
            color_hex = f"#{color_rgb[0]:02x}{color_rgb[1]:02x}{color_rgb[2]:02x}"

            html += f"""
            <div style='margin:8px 0; padding:8px; border-left:3px solid {color_hex}; background:rgba(0,0,0,0.02);'>
            <b>Region {i}:</b> {forgery_type} ({confidence*100:.1f}%)<br>
            <small>Location: ({bbox[0]}, {bbox[1]}) | Size: {bbox[2]}×{bbox[3]}px</small>
            </div>
            """

        html += """
        </div>
        """

        return html
409
-
410
-
411
# Initialize detector: one module-level instance, created at import time
# (its constructor loads the models) and used by the Gradio callback below.
detector = ForgeryDetector()
413
 
 
 
 
 
 
414
 
415
def detect_forgery(file_input, webcam_input):
    """Gradio interface function - handles unified file upload and webcam.

    Args:
        file_input: Path of the uploaded file, or ``None``.
        webcam_input: Path of the webcam capture, or ``None``; used only when
            ``file_input`` is ``None``.

    Returns:
        ``(overlay, metrics_gauge, results_html)``. When nothing was provided
        or detection raises, the first two items are ``None`` and the third is
        an HTML error card.
    """
    import html
    import traceback

    try:
        # Prioritize file upload, fallback to webcam
        if file_input is not None:
            file_path = file_input
        elif webcam_input is not None:
            file_path = webcam_input
        else:
            empty_html = "<div style='padding:12px; border:1px solid #d9534f; border-radius:8px;'>❌ <b>No file uploaded.</b></div>"
            return None, None, empty_html

        # Detect forgeries
        overlay, metrics_gauge, results_html = detector.detect(file_path)

        return overlay, metrics_gauge, results_html

    except Exception as e:
        # Top-level UI boundary: log the full traceback to the console and
        # show the user a short message instead of crashing the callback.
        error_details = traceback.format_exc()
        print(f"Error: {error_details}")
        # Exception text can contain file names or other input-derived
        # characters; escape it so it cannot inject markup into the page.
        safe_message = html.escape(str(e))
        error_html = f"""
        <div style='padding:12px; border:1px solid #d9534f; border-radius:8px;'>
        ❌ <b>Error:</b> {safe_message}
        </div>
        """
        return None, None, error_html
442
-
443
-
444
# Custom CSS - subtle styling for the two action buttons
custom_css = """
.predict-btn {
    background-color: #4169E1 !important;
    color: white !important;
}
.clear-btn {
    background-color: #6A89A7 !important;
    color: white !important;
}
"""

# Placeholder shown in the report panel before the first analysis and after "Clear".
_NO_ANALYSIS_HTML = "<i>No analysis yet. Upload a document and click Analyze.</i>"

# Create Gradio interface
# NOTE(review): the nesting of rows/columns below is reconstructed from the
# statement order; confirm against the rendered layout.
with gr.Blocks(css=custom_css) as demo:

    gr.Markdown(
        """
        # 📄 Document Forgery Detection
        Upload a document image or PDF to detect and classify forgeries using deep learning. The system combines MobileNetV3-UNet for precise localization and LightGBM for classification, identifying Copy-Move, Splicing, and Text Substitution manipulations with detailed confidence scores and bounding boxes. Trained on 140K samples for robust performance.
        """
    )
    gr.Markdown("---")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Upload Document")

            # Single unified input that accepts images and PDFs
            input_file = gr.File(
                label="📤 Upload Image/PDF or 📷 Use Webcam",
                file_types=["image", ".pdf"],
                type="filepath"
            )

            # Hidden webcam component, revealed by the button below
            input_webcam = gr.Image(
                label="Webcam Capture",
                type="filepath",
                sources=["webcam"],
                visible=False
            )

            # Button to trigger webcam
            with gr.Row():
                webcam_btn = gr.Button("📷 Open Webcam", size="sm")

            with gr.Row():
                clear_btn = gr.Button("🧹 Clear", elem_classes="clear-btn")
                analyze_btn = gr.Button("🔍 Analyze", elem_classes="predict-btn")

        with gr.Column(scale=1):
            gr.Markdown("### Information")
            # Legend colors follow CLASS_COLORS: Copy-Move #d9534f,
            # Splicing #5cb85c, Text Substitution #4169E1, so the legend
            # matches the boxes drawn on the image.
            gr.HTML(
                """
                <div style='padding:16px; border:1px solid #ccc; border-radius:8px; background:var(--background-fill-primary);'>
                <p style='margin-top:0;'><b>Supported formats:</b></p>
                <ul style='margin:8px 0; padding-left:20px;'>
                <li>Images: JPG, PNG, BMP, TIFF, WebP</li>
                <li>PDF: First page analyzed</li>
                </ul>

                <p style='margin-bottom:4px;'><b>Forgery types:</b></p>
                <ul style='margin:8px 0; padding-left:20px;'>
                <li style='color:#d9534f;'><b>Copy-Move:</b> <span style='color:inherit;'>Duplicated regions</span></li>
                <li style='color:#5cb85c;'><b>Splicing:</b> <span style='color:inherit;'>Mixed sources</span></li>
                <li style='color:#4169E1;'><b>Text Substitution:</b> <span style='color:inherit;'>Modified text</span></li>
                </ul>
                </div>
                """
            )

        with gr.Column(scale=2):
            gr.Markdown("### Detection Results")
            output_image = gr.Image(label="Detected Forgeries", type="numpy")

    gr.Markdown("---")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Analysis Report")
            output_html = gr.HTML(
                value=_NO_ANALYSIS_HTML
            )

        with gr.Column(scale=1):
            gr.Markdown("### Detection Metrics")
            metrics_gauge = gr.Plot(label="Concentric Metrics Gauge")

    gr.Markdown("---")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Model Architecture")
            gr.HTML(
                """
                <div style='padding:12px; border:1px solid #444; border-radius:10px; background:var(--background-fill-primary);'>
                <p style="margin:0 0 0px 0; font-size:1.05em;"><b>Localization:</b> MobileNetV3-Small + UNet</p>
                <p style='margin:0 20px 5px 0; margin-left:0.5cm; font-size:0.9em; opacity:0.85;'>Dice: 62.12% | IoU: 45.06% | Precision: 70.77% | Recall: 55.36%</p>

                <p style="margin:0 0 0 0; font-size:1.05em;"><b>Classification:</b> LightGBM with 526 features</p>
                <p style="margin:0 20px 0 0; margin-left:0.5cm; font-size:0.9em; opacity:0.85;">Train Accuracy: 90.53% | Val Accuracy: 88.97%</p>

                <p style='margin-top:5px; margin-bottom:0; font-size:1.05em;'><b>Training:</b> 140K samples from DocTamper dataset</p>
                </div>
                """
            )

        with gr.Column(scale=1):
            gr.Markdown("### Model Performance")
            gr.HTML(
                f"""
                <div style='padding:12px; border:1px solid #444; border-radius:10px; background:var(--background-fill-primary);'>
                <p style='margin-top:0; margin-bottom:12px;'><b>Trained Model Performance:</b></p>

                <b>Segmentation Dice: {MODEL_METRICS['segmentation']['dice']*100:.2f}%</b>
                <div style='width:100%; background:#333; height:12px; border-radius:6px; margin-bottom:12px;'>
                <div style='width:{MODEL_METRICS['segmentation']['dice']*100:.1f}%; background:#4169E1; height:12px; border-radius:6px;'></div>
                </div>

                <b>Classification Accuracy: {MODEL_METRICS['classification']['overall_accuracy']*100:.2f}%</b>
                <div style='width:100%; background:#333; height:12px; border-radius:6px;'>
                <div style='width:{MODEL_METRICS['classification']['overall_accuracy']*100:.1f}%; background:#5cb85c; height:12px; border-radius:6px;'></div>
                </div>
                </div>
                """
            )

    # Event handlers
    # Reveal the webcam component on demand
    webcam_btn.click(
        fn=lambda: gr.update(visible=True),
        inputs=None,
        outputs=[input_webcam]
    )

    analyze_btn.click(
        fn=detect_forgery,
        inputs=[input_file, input_webcam],
        outputs=[output_image, metrics_gauge, output_html]
    )

    # Reset every input/output. The webcam is cleared and hidden with a single
    # update so each component appears only once in ``outputs``.
    clear_btn.click(
        fn=lambda: (None, gr.update(value=None, visible=False), None, None, _NO_ANALYSIS_HTML),
        inputs=None,
        outputs=[input_file, input_webcam, output_image, metrics_gauge, output_html]
    )

if __name__ == "__main__":
    demo.launch()
 
1
  """
2
+ Document Forgery Detection Professional Gradio Dashboard
3
+ Hugging Face Spaces Deployment
 
4
  """
5
 
6
  import gradio as gr
 
8
  import cv2
9
  import numpy as np
10
  from PIL import Image
11
+ import plotly.graph_objects as go
12
  from pathlib import Path
13
  import sys
14
+ import json
 
15
 
16
+ # -------------------------------------------------
17
+ # PATH SETUP
18
+ # -------------------------------------------------
19
  sys.path.insert(0, str(Path(__file__).parent))
20
 
21
  from src.models import get_model
 
26
  from src.features.feature_extraction import get_feature_extractor
27
  from src.training.classifier import ForgeryClassifier
28
 
29
+ # -------------------------------------------------
30
+ # CONSTANTS
31
+ # -------------------------------------------------
32
# Forgery class index -> display label.
CLASS_NAMES = dict(enumerate(("Copy-Move", "Splicing", "Generation")))

# Forgery class index -> color triple used for that class's box and label.
CLASS_COLORS = dict(enumerate((
    (255, 0, 0),
    (0, 255, 0),
    (0, 0, 255),
)))
38
 
39
+ # -------------------------------------------------
40
+ # FORGERY DETECTOR (UNCHANGED CORE LOGIC)
41
+ # -------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
class ForgeryDetector:
    """Forgery pipeline: segment suspicious pixels, classify each region, draw the result."""

    # Feature-vector width handed to the classifier; shorter vectors are
    # zero-padded and longer ones truncated to this length.
    FEATURE_DIM = 526
    # A region is reported only when its classifier confidence exceeds this.
    MIN_CONFIDENCE = 0.6

    def __init__(self):
        """Load config, segmentation checkpoint, classifier and helper components."""
        print("Loading models...")

        self.config = get_config("config.yaml")
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.model = get_model(self.config).to(self.device)
        checkpoint = torch.load("models/best_doctamper.pth", map_location=self.device)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.model.eval()

        self.classifier = ForgeryClassifier(self.config)
        self.classifier.load("models/classifier")

        self.preprocessor = DocumentPreprocessor(self.config, "doctamper")
        self.augmentation = DatasetAwareAugmentation(self.config, "doctamper", is_training=False)
        self.mask_refiner = get_mask_refiner(self.config)
        self.region_extractor = get_region_extractor(self.config)
        self.feature_extractor = get_feature_extractor(self.config, is_text_document=True)

        print("✓ Models loaded")

    def detect(self, image):
        """Run the full pipeline on one image.

        Args:
            image: A PIL image or a NumPy array (grayscale, RGB or RGBA).

        Returns:
            ``(overlay, summary)`` where ``overlay`` is a copy of the input
            with boxes and labels drawn on it, and ``summary`` is
            ``{"num_detections": int, "detections": [dict, ...]}``; each
            detection has ``bounding_box``, ``forgery_type`` and ``confidence``.
        """
        if isinstance(image, Image.Image):
            image = np.array(image)

        # Normalize to 3 channels
        if image.ndim == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

        original = image.copy()

        preprocessed, _ = self.preprocessor(image, None)
        augmented = self.augmentation(preprocessed, None)
        image_tensor = augmented["image"].unsqueeze(0).to(self.device)

        with torch.no_grad():
            logits, decoder_features = self.model(image_tensor)
            prob_map = torch.sigmoid(logits).cpu().numpy()[0, 0]

        binary = (prob_map > 0.5).astype(np.uint8)
        refined = self.mask_refiner.refine(binary, original_size=original.shape[:2])

        # The probability map is at the model's output resolution, while the
        # refiner is asked for ``original_size``. Bring the map to the refined
        # mask's size so both inputs to the region extractor index the same
        # pixels (an earlier revision resized it explicitly to avoid index
        # mismatch errors).
        # NOTE(review): assumes ``refine`` returns a 2-D array - confirm.
        mask_h, mask_w = np.shape(refined)[:2]
        if prob_map.shape[:2] != (mask_h, mask_w):
            prob_map = cv2.resize(prob_map, (mask_w, mask_h), interpolation=cv2.INTER_LINEAR)

        regions = self.region_extractor.extract(refined, prob_map, original)

        # The decoder features are the same for every region: move them to the
        # CPU once instead of once per region.
        cpu_features = [f.cpu() for f in decoder_features]

        results = []
        for r in regions:
            features = self.feature_extractor.extract(
                preprocessed, r["region_mask"], cpu_features
            )

            if features.ndim == 1:
                features = features.reshape(1, -1)

            # Zero-pad or truncate to the classifier's expected width
            if features.shape[1] != self.FEATURE_DIM:
                pad = max(0, self.FEATURE_DIM - features.shape[1])
                features = np.pad(features, ((0, 0), (0, pad)))[:, :self.FEATURE_DIM]

            pred, conf = self.classifier.predict(features)
            if conf[0] > self.MIN_CONFIDENCE:
                results.append({
                    # Plain ints keep the summary JSON-serializable and are
                    # accepted by OpenCV even if the extractor yields NumPy
                    # scalars.
                    "bounding_box": tuple(int(v) for v in r["bounding_box"]),
                    "forgery_type": CLASS_NAMES[int(pred[0])],
                    "confidence": float(conf[0]),
                })

        overlay = self._draw_overlay(original, results)

        return overlay, {
            "num_detections": len(results),
            "detections": results,
        }

    def _draw_overlay(self, image, results):
        """Return a copy of ``image`` with a colored box and text label per detection."""
        out = image.copy()
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.5
        thickness = 2

        for r in results:
            x, y, w, h = (int(v) for v in r["bounding_box"])
            fid = next(k for k, v in CLASS_NAMES.items() if v == r["forgery_type"])
            color = CLASS_COLORS[fid]

            cv2.rectangle(out, (x, y), (x + w, y + h), color, 2)

            label = f"{r['forgery_type']} ({r['confidence']*100:.1f}%)"
            (_text_w, text_h), _baseline = cv2.getTextSize(label, font, font_scale, thickness)

            # Default: baseline 6px above the box. If the text would run off
            # the top of the image, draw it just inside the box instead.
            baseline_y = y - 6
            if baseline_y - text_h < 0:
                baseline_y = y + text_h + 6

            cv2.putText(out, label, (x, baseline_y), font, font_scale, color, thickness)
        return out
128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
# Single module-level detector, created at import time (its constructor loads
# the models) and used by the Gradio callback.
detector = ForgeryDetector()
131
+
132
+ # -------------------------------------------------
133
+ # METRIC VISUALS
134
+ # -------------------------------------------------
135
def gauge(value, title):
    """Return a Plotly gauge showing ``value`` on a fixed 0-100 axis, captioned ``title``."""
    dial_style = {"axis": {"range": [0, 100]}, "bar": {"color": "#2563eb"}}
    indicator = go.Indicator(
        mode="gauge+number",
        value=value,
        title={"text": title},
        gauge=dial_style,
    )
    figure = go.Figure(indicator)
    figure.update_layout(height=240, margin={"t": 40, "b": 20})
    return figure
144
+
145
+ # -------------------------------------------------
146
+ # GRADIO CALLBACK
147
+ # -------------------------------------------------
148
def _render_pdf_first_page(path):
    """Rasterize the first page of the PDF at ``path`` (2x zoom) into a uint8 array."""
    import fitz  # PyMuPDF; imported lazily so plain image uploads do not need it

    pdf_document = fitz.open(path)
    try:
        pix = pdf_document[0].get_pixmap(matrix=fitz.Matrix(2, 2))
        page = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, pix.n)
        # ``frombuffer`` yields a read-only view; copy so later stages may write.
        return page.copy()
    finally:
        # Close the document even when rendering fails.
        pdf_document.close()


def run_detection(file):
    """Gradio callback: analyze the uploaded document and build all outputs.

    Args:
        file: Value of the ``gr.File`` input - ``None`` when nothing is
            uploaded, otherwise a path string or an object exposing ``.name``.

    Returns:
        ``(overlay, result, dice_gauge, accuracy_gauge, confidence_gauge)``.
        With no upload, the images/plots are ``None`` and ``result`` is an
        empty summary.
    """
    if file is None:
        # Button clicked with no upload: clear the outputs instead of failing
        # on ``None.name``.
        return None, {"num_detections": 0, "detections": []}, None, None, None

    # gr.File passes either a path string or a tempfile-like wrapper with
    # ``.name``, depending on the Gradio version / ``type`` setting.
    path = file if isinstance(file, str) else file.name

    if path.lower().endswith(".pdf"):
        # The upload label advertises PDFs, but PIL cannot open them.
        image = _render_pdf_first_page(path)
    else:
        # ``convert`` returns a fully loaded RGB copy, so the file handle can
        # be released here; it also expands palette images to real colors.
        with Image.open(path) as handle:
            image = handle.convert("RGB")

    overlay, result = detector.detect(image)

    # ``max(1, ...)`` keeps the average at 0 when there are no detections.
    avg_conf = (
        sum(d["confidence"] for d in result["detections"]) / max(1, result["num_detections"])
    ) * 100

    # NOTE(review): the Dice and accuracy gauges show hard-coded values (75 and
    # 92), not anything measured in this run - confirm they match the deployed
    # checkpoint's evaluation results.
    return (
        overlay,
        result,
        gauge(75, "Localization Dice (%)"),
        gauge(92, "Classifier Accuracy (%)"),
        gauge(avg_conf, "Avg Detection Confidence (%)"),
    )
163
+
164
+ # -------------------------------------------------
165
+ # UI
166
+ # -------------------------------------------------
167
# Top-level UI: upload row, result image, then tabs for metrics, raw details
# and team credits.
with gr.Blocks(theme=gr.themes.Soft(), title="Document Forgery Detection") as demo:

    gr.Markdown("# 📄 Document Forgery Detection System")

    # Input controls
    with gr.Row():
        file_input = gr.File(label="Upload Document (Image/PDF)")
        detect_btn = gr.Button("Run Detection", variant="primary")

    # Annotated image returned by the detector
    output_img = gr.Image(label="Forgery Localization Result", type="numpy")

    with gr.Tabs():
        with gr.Tab("📊 Metrics"):
            # Three gauges side by side
            with gr.Row():
                dice_plot = gr.Plot()
                acc_plot = gr.Plot()
                conf_plot = gr.Plot()

        with gr.Tab("🧾 Details"):
            # Raw detection summary
            json_out = gr.JSON()

        with gr.Tab("👥 Team"):
            gr.Markdown("""
            **Document Forgery Detection Project**

            - Krishnanandhaa — Model & Training
            - Teammate 1 — Feature Engineering
            - Teammate 2 — Evaluation
            - Teammate 3 — Deployment

            *Collaborators are added via Hugging Face Space settings.*
            """)

    # Output order must match the tuple returned by ``run_detection``.
    detect_btn.click(
        fn=run_detection,
        inputs=[file_input],
        outputs=[output_img, json_out, dice_plot, acc_plot, conf_plot],
    )

if __name__ == "__main__":
    demo.launch()