JKrishnanandhaa committed on
Commit
d192120
·
verified ·
1 Parent(s): 86140ca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +206 -651
app.py CHANGED
@@ -1,698 +1,253 @@
1
  """
2
- Document Forgery Detection - Gradio Interface for Hugging Face Spaces
3
-
4
- This app provides a web interface for detecting and classifying document forgeries.
5
  """
6
 
7
- import gradio as gr
8
- import torch
9
  import cv2
10
  import numpy as np
11
- from PIL import Image
12
- import json
13
- from pathlib import Path
14
- import sys
15
- from typing import Dict, List, Tuple
16
- import plotly.graph_objects as go
17
-
18
- # Add src to path
19
- sys.path.insert(0, str(Path(__file__).parent))
20
-
21
- from src.models import get_model
22
- from src.config import get_config
23
- from src.data.preprocessing import DocumentPreprocessor
24
- from src.data.augmentation import DatasetAwareAugmentation
25
- from src.features.region_extraction import get_mask_refiner, get_region_extractor
26
- from src.features.feature_extraction import get_feature_extractor
27
- from src.training.classifier import ForgeryClassifier
28
-
29
# Forgery class labels, keyed by the classifier's integer class id
CLASS_NAMES = {
    0: 'Copy-Move',
    1: 'Splicing',
    2: 'Text Substitution',
}

# RGB drawing colour for each class id
CLASS_COLORS = {
    0: (217, 83, 79),    # #d9534f - Muted red
    1: (92, 184, 92),    # #5cb85c - Muted green
    2: (65, 105, 225),   # #4169E1 - Royal blue
}

# Actual model performance metrics
MODEL_METRICS = dict(
    segmentation=dict(
        dice=0.6212,
        iou=0.4506,
        precision=0.7077,
        recall=0.5536,
    ),
    classification=dict(
        overall_accuracy=0.8897,
        per_class=dict(
            copy_move=0.92,
            splicing=0.85,
            generation=0.90,
        ),
    ),
)
54
 
55
-
56
def create_gauge_chart(value: float, title: str, max_value: float = 1.0) -> go.Figure:
    """
    Create a subtle radial gauge chart

    Args:
        value: Metric value on the scale [0, max_value]
        title: Caption shown above the gauge
        max_value: Value that corresponds to a full (100%) gauge

    Returns:
        Plotly figure holding a single percentage gauge

    Raises:
        ValueError: If max_value is not positive
    """
    if max_value <= 0:
        raise ValueError(f"max_value must be positive, got {max_value}")

    # Honour max_value (previously ignored); with the default of 1.0 this is
    # exactly value * 100, so existing callers see the same gauge
    percent = value / max_value * 100

    fig = go.Figure(go.Indicator(
        mode="gauge+number",
        value=percent,
        domain={'x': [0, 1], 'y': [0, 1]},
        title={'text': title, 'font': {'size': 14}},
        number={'suffix': '%', 'font': {'size': 24}},
        gauge={
            'axis': {'range': [0, 100], 'tickwidth': 1},
            'bar': {'color': '#4169E1', 'thickness': 0.7},
            'bgcolor': 'rgba(0,0,0,0)',
            'borderwidth': 0,
            # Faint red / amber / green bands behind the bar
            'steps': [
                {'range': [0, 50], 'color': 'rgba(217, 83, 79, 0.1)'},
                {'range': [50, 75], 'color': 'rgba(240, 173, 78, 0.1)'},
                {'range': [75, 100], 'color': 'rgba(92, 184, 92, 0.1)'},
            ],
        },
    ))

    fig.update_layout(
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        height=200,
        margin=dict(l=20, r=20, t=40, b=20),
    )

    return fig
85
-
86
-
87
def create_detection_metrics_gauge(avg_confidence: float, iou: float, precision: float, recall: float, num_detections: int) -> go.Figure:
    """
    Build the concentric-ring (radial bar) chart of detection metrics

    Each metric gets its own ring: a faint full-circle track plus a coloured
    arc whose sweep is proportional to the metric (100% = 360 degrees).
    Confidence is shown as 0 when there are no detections.
    """
    confidence_pct = avg_confidence * 100 if num_detections > 0 else 0

    # (legend name, percentage 0-100, arc colour, inner radius of the ring)
    rings = (
        ('Confidence', confidence_pct, '#4169E1', 80),
        ('Precision', precision * 100, '#5cb85c', 60),
        ('Recall', recall * 100, '#f0ad4e', 40),
        ('IoU', iou * 100, '#d9534f', 20),
    )

    figure = go.Figure()

    for ring_name, percent, ring_color, ring_base in rings:
        # Background track (faint gray ring)
        track = go.Barpolar(
            r=[15],
            theta=[180],
            width=[360],
            base=ring_base,
            marker_color='rgba(128,128,128,0.1)',
            hoverinfo='none',
            showlegend=False,
        )
        figure.add_trace(track)

        # Coloured arc, centred so that it starts at angle 0
        sweep = percent * 3.6
        arc = go.Barpolar(
            r=[15],
            theta=[sweep / 2],
            width=[sweep],
            base=ring_base,
            name=f"{ring_name}: {percent:.1f}%",
            marker_color=ring_color,
            marker_line_width=0,
            hoverinfo='name',
        )
        figure.add_trace(arc)

    angular_axis = dict(
        rotation=90,  # Start at 12 o'clock
        direction='clockwise',  # Go clockwise
        gridcolor='rgba(128,128,128,0.2)',
        tickmode='array',
        tickvals=[0, 90, 180, 270],
        ticktext=['0%', '25%', '50%', '75%'],
        showticklabels=True,
        tickfont=dict(size=12, color='#888'),
    )
    legend_style = dict(
        orientation="v",
        yanchor="middle",
        y=0.5,
        xanchor="left",
        x=1.1,
        font=dict(size=14, color='white'),
        itemwidth=30,
    )

    figure.update_layout(
        polar=dict(
            hole=0.1,
            radialaxis=dict(visible=False, range=[0, 100]),
            angularaxis=angular_axis,
            bgcolor='rgba(0,0,0,0)',
        ),
        showlegend=True,
        legend=legend_style,
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        height=450,
        margin=dict(l=60, r=180, t=40, b=40),
    )

    return figure
159
-
160
-
161
class ForgeryDetector:
    """
    Main forgery detection pipeline

    Loads the segmentation model, the region classifier and the supporting
    pre/post-processing components once in __init__, then analyses one
    document per detect() call.
    """

    # Feature vector width the classifier is called with (features are
    # zero-padded or truncated to this width)
    EXPECTED_FEATURES = 526
    # Regions whose classifier confidence is not above this value are dropped
    MIN_CONFIDENCE = 0.6
    # Probability above which a pixel counts as "high confidence" when the
    # gauge metrics are derived
    HIGH_CONFIDENCE_PROB = 0.7

    def __init__(self):
        """Load configuration, model weights, classifier and helper components"""
        print("Loading models...")

        # Load config
        self.config = get_config('config.yaml')
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Load segmentation model
        self.model = get_model(self.config).to(self.device)
        checkpoint = torch.load('models/best_doctamper.pth', map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()

        # Load classifier
        self.classifier = ForgeryClassifier(self.config)
        self.classifier.load('models/classifier')

        # Initialize components
        self.preprocessor = DocumentPreprocessor(self.config, 'doctamper')
        self.augmentation = DatasetAwareAugmentation(self.config, 'doctamper', is_training=False)
        self.mask_refiner = get_mask_refiner(self.config)
        self.region_extractor = get_region_extractor(self.config)
        self.feature_extractor = get_feature_extractor(self.config, is_text_document=True)

        print(" Models loaded successfully!")

    def detect(self, image):
        """
        Detect forgeries in document image or PDF

        Args:
            image: File path (image, PDF, .doc/.docx), PIL image or numpy array

        Returns:
            overlay: RGB image with a labelled box per accepted detection
            metrics_gauge: Plotly figure from create_detection_metrics_gauge
            results_html: Detection results as HTML

        Raises:
            ValueError: If a Word document cannot be converted or rendered
        """
        image = self._load_input(image)
        original_image = image.copy()

        # Preprocess
        preprocessed, _ = self.preprocessor(image, None)

        # Augment
        augmented = self.augmentation(preprocessed, None)
        image_tensor = augmented['image'].unsqueeze(0).to(self.device)

        # Run localization
        with torch.no_grad():
            logits, decoder_features = self.model(image_tensor)
            prob_map = torch.sigmoid(logits).cpu().numpy()[0, 0]

        # Resize probability map to match original image size to avoid index mismatch errors
        height, width = original_image.shape[:2]
        prob_map_resized = cv2.resize(
            prob_map,
            (width, height),  # cv2 uses (W, H)
            interpolation=cv2.INTER_LINEAR
        )

        # Refine mask
        refined_mask = self.mask_refiner.refine(prob_map_resized, original_size=(height, width))

        # Ensure refined_mask matches prob_map_resized dimensions; after this
        # both arrays share one shape, so no second check is needed
        if refined_mask.shape != prob_map_resized.shape:
            refined_mask = cv2.resize(
                refined_mask,
                (width, height),
                interpolation=cv2.INTER_NEAREST
            )

        # Extract and classify regions
        regions = self.region_extractor.extract(refined_mask, prob_map_resized, original_image)
        results = self._classify_regions(regions, preprocessed, decoder_features)

        # Create visualization
        overlay = self._create_overlay(original_image, results)

        # Gauge metrics: agreement between the refined mask and the
        # high-confidence part of the probability map (zeros when nothing was detected)
        num_detections = len(results)
        if num_detections > 0:
            avg_confidence = sum(r['confidence'] for r in results) / num_detections
            iou, precision, recall = self._mask_agreement(
                refined_mask, prob_map_resized, self.HIGH_CONFIDENCE_PROB
            )
        else:
            avg_confidence = 0
            iou = precision = recall = 0

        metrics_gauge = create_detection_metrics_gauge(avg_confidence, iou, precision, recall, num_detections)

        # Create HTML response
        results_html = self._create_html_report(results)

        return overlay, metrics_gauge, results_html

    def _load_input(self, source) -> np.ndarray:
        """Turn a file path, PIL image or numpy array into an RGB numpy image"""
        if isinstance(source, str):
            lowered = source.lower()
            if lowered.endswith(('.doc', '.docx')):
                image = self._load_word_document(source)
            elif lowered.endswith('.pdf'):
                image = self._render_pdf_first_page(source)
            else:
                # convert('RGB') also handles palette, LA and CMYK files, which
                # a plain np.array() would hand on as index maps or odd channel counts
                with Image.open(source) as handle:
                    image = np.array(handle.convert('RGB'))
        elif isinstance(source, Image.Image):
            image = np.array(source.convert('RGB'))
        else:
            image = source

        # Convert numpy inputs to RGB
        if image.ndim == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

        return image

    def _load_word_document(self, doc_path: str) -> np.ndarray:
        """
        Render a Word document to an image

        Tries a PDF conversion first (docx2pdf, then LibreOffice) and renders
        the first PDF page; if no converter is available, falls back to
        drawing the document's text onto a blank page.

        Raises:
            ValueError: If every strategy fails
        """
        import tempfile

        try:
            # All intermediate files live in one directory that is removed on
            # exit, so a failed strategy cannot leave temp files behind
            with tempfile.TemporaryDirectory() as workdir:
                pdf_path = self._convert_word_to_pdf(doc_path, workdir)
                if pdf_path is not None:
                    return self._render_pdf_first_page(pdf_path)
            return self._render_docx_text(doc_path)
        except Exception as e:
            raise ValueError(f"Could not process Word document. Please convert to PDF or image first. Error: {str(e)}") from e

    @staticmethod
    def _convert_word_to_pdf(doc_path: str, workdir: str):
        """Convert a Word file to PDF inside workdir; return the PDF path or None"""
        import os
        import subprocess

        # Strategy 1: Try docx2pdf (Windows with MS Word)
        target = os.path.join(workdir, 'converted.pdf')
        try:
            from docx2pdf import convert
            convert(doc_path, target)
        except Exception:
            # Deliberate best-effort: docx2pdf may be missing or unusable on
            # this platform, in which case the next strategy is tried
            pass
        else:
            if os.path.exists(target):
                return target

        # Strategy 2: Try LibreOffice (Linux/Mac)
        try:
            subprocess.run([
                'libreoffice', '--headless', '--convert-to', 'pdf',
                '--outdir', workdir,
                doc_path
            ], check=True, capture_output=True)
        except (OSError, subprocess.SubprocessError):
            return None

        # LibreOffice creates file with original name + .pdf
        base_name = os.path.splitext(os.path.basename(doc_path))[0]
        generated_pdf = os.path.join(workdir, f"{base_name}.pdf")
        return generated_pdf if os.path.exists(generated_pdf) else None

    @staticmethod
    def _render_docx_text(doc_path: str) -> np.ndarray:
        """Draw the text of the first paragraphs of a .docx onto a white page image"""
        from docx import Document

        doc = Document(doc_path)

        # First 40 paragraphs, skipping blank ones, max 100 chars per line
        text_lines = [
            para.text[:100]
            for para in doc.paragraphs[:40]
            if para.text.strip()
        ]

        # White 1000x1400 canvas
        page = np.full((1400, 1000, 3), 255, dtype=np.uint8)

        y_offset = 60
        for line in text_lines[:35]:
            cv2.putText(page, line, (40, y_offset),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.45, (0, 0, 0), 1, cv2.LINE_AA)
            y_offset += 35

        return page

    @staticmethod
    def _render_pdf_first_page(pdf_path: str) -> np.ndarray:
        """Rasterise the first page of a PDF at 2x zoom and return it as an array"""
        import fitz  # PyMuPDF

        pdf_document = fitz.open(pdf_path)
        try:
            pix = pdf_document[0].get_pixmap(matrix=fitz.Matrix(2, 2))
            page = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, pix.n)
            if pix.n == 4:
                return cv2.cvtColor(page, cv2.COLOR_RGBA2RGB)
            # frombuffer yields a read-only view; hand back a writable copy
            return page.copy()
        finally:
            # Always release the document, even if rendering failed
            pdf_document.close()

    def _classify_regions(self, regions, preprocessed, decoder_features):
        """
        Classify every extracted region and keep the confident ones

        Returns:
            List of dicts with region_id, bounding_box, forgery_type, confidence
        """
        results = []
        if not regions:
            return results

        # Loop-invariant: move decoder features to CPU once for all regions
        cpu_features = [f.cpu() for f in decoder_features]

        for region in regions:
            # NOTE(review): the full-resolution region mask is passed here, as
            # before; an earlier unused resize to the decoder feature size was
            # removed - confirm the extractor does not expect the resized mask
            features = self.feature_extractor.extract(
                preprocessed,
                region['region_mask'],
                cpu_features
            )

            # Pad/truncate features to match classifier
            features = self._fit_feature_width(features, self.EXPECTED_FEATURES)

            # Classify
            predictions, confidences = self.classifier.predict(features)
            forgery_type = int(predictions[0])
            confidence = float(confidences[0])

            if confidence > self.MIN_CONFIDENCE:
                results.append({
                    'region_id': region['region_id'],
                    'bounding_box': region['bounding_box'],
                    'forgery_type': CLASS_NAMES[forgery_type],
                    'confidence': confidence
                })

        return results

    @staticmethod
    def _fit_feature_width(features, width: int) -> np.ndarray:
        """Return features as a 2D array with exactly `width` columns (zero-padded or truncated)"""
        features = np.asarray(features)
        if features.ndim == 1:
            features = features.reshape(1, -1)

        current = features.shape[1]
        if current < width:
            padding = np.zeros((features.shape[0], width - current))
            features = np.hstack([features, padding])
        elif current > width:
            features = features[:, :width]

        return features

    @staticmethod
    def _mask_agreement(refined_mask: np.ndarray, prob_map: np.ndarray, high_conf_threshold: float = 0.7):
        """
        Compare the refined mask with the high-confidence part of the probability map

        Returns:
            (iou, precision, recall) where precision is relative to the refined
            mask and recall relative to the high-confidence pixels; each value
            is 0.0 when its denominator is empty
        """
        predicted = refined_mask > 0
        confident = prob_map > high_conf_threshold

        intersection = np.count_nonzero(predicted & confident)
        union = np.count_nonzero(predicted | confident)
        predicted_positive = np.count_nonzero(predicted)
        high_conf_positive = np.count_nonzero(confident)

        iou = intersection / union if union > 0 else 0.0
        precision = intersection / predicted_positive if predicted_positive > 0 else 0.0
        recall = intersection / high_conf_positive if high_conf_positive > 0 else 0.0
        return iou, precision, recall

    @staticmethod
    def _class_color(forgery_type: str):
        """Look up the RGB colour of a forgery type by its display name"""
        forgery_id = next(k for k, v in CLASS_NAMES.items() if v == forgery_type)
        return CLASS_COLORS[forgery_id]

    def _create_overlay(self, image, results):
        """Create overlay visualization: one coloured box and label per detection"""
        overlay = image.copy()
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.5
        thickness = 1

        for result in results:
            x, y, w, h = result['bounding_box']
            forgery_type = result['forgery_type']
            confidence = result['confidence']
            color = self._class_color(forgery_type)

            # Draw rectangle
            cv2.rectangle(overlay, (x, y), (x + w, y + h), color, 2)

            # Draw label
            label = f"{forgery_type}: {confidence:.1%}"
            (label_w, label_h), _ = cv2.getTextSize(label, font, font_scale, thickness)

            if y - label_h - 8 >= 0:
                # Usual placement: label sits on top of the box
                box_top, box_bottom, text_y = y - label_h - 8, y, y - 4
            else:
                # Region touches the top edge: put the label inside the box so
                # it is not clipped off the image
                box_top, box_bottom, text_y = y, y + label_h + 8, y + label_h + 4

            cv2.rectangle(overlay, (x, box_top), (x + label_w + 4, box_bottom), color, -1)
            cv2.putText(overlay, label, (x + 2, text_y), font, font_scale, (255, 255, 255), thickness)

        return overlay

    def _create_html_report(self, results):
        """Create HTML report with detection results"""
        num_detections = len(results)

        if num_detections == 0:
            return """
            <div style='padding:12px; border:1px solid #5cb85c; border-radius:8px;'>
            ✓ <b>No forgery detected.</b><br>
            The document appears to be authentic.
            </div>
            """

        # Calculate statistics
        avg_confidence = sum(r['confidence'] for r in results) / num_detections

        parts = [f"""
            <div style='padding:12px; border:1px solid #d9534f; border-radius:8px;'>
            <b>⚠️ Forgery Detected</b><br><br>

            <b>Summary:</b><br>
            • Regions detected: {num_detections}<br>
            • Average confidence: {avg_confidence*100:.1f}%<br><br>

            <b>Detections:</b><br>
            """]

        for i, result in enumerate(results, 1):
            forgery_type = result['forgery_type']
            confidence = result['confidence']
            bbox = result['bounding_box']

            color_rgb = self._class_color(forgery_type)
            color_hex = f"#{color_rgb[0]:02x}{color_rgb[1]:02x}{color_rgb[2]:02x}"

            parts.append(f"""
            <div style='margin:8px 0; padding:8px; border-left:3px solid {color_hex}; background:rgba(0,0,0,0.02);'>
            <b>Region {i}:</b> {forgery_type} ({confidence*100:.1f}%)<br>
            <small>Location: ({bbox[0]}, {bbox[1]}) | Size: {bbox[2]}×{bbox[3]}px</small>
            </div>
            """)

        parts.append("""
            </div>
            """)

        return "".join(parts)
527
 
528
 
529
# Module-level detector instance used by detect_forgery below
detector = ForgeryDetector()
531
-
532
-
533
def detect_forgery(file, webcam):
    """
    Gradio interface function - handles file uploads and webcam capture

    Args:
        file: Uploaded file (path) or None
        webcam: Webcam capture (path) or None; only used when file is None

    Returns:
        (overlay, metrics_gauge, results_html); on missing input or on any
        error the first two are None and the third is an HTML message
    """
    # Use whichever input has data (the uploaded file takes precedence)
    source = file if file is not None else webcam

    if source is None:
        empty_html = "<div style='padding:12px; border:1px solid #d9534f; border-radius:8px;'>❌ <b>No input provided.</b> Please upload a file or use webcam.</div>"
        return None, None, empty_html

    try:
        # Detect forgeries
        overlay, metrics_gauge, results_html = detector.detect(source)
        return overlay, metrics_gauge, results_html

    except Exception as e:
        # Top-level UI boundary: log the full traceback, show a short message
        import traceback
        from html import escape

        error_details = traceback.format_exc()
        print(f"Error: {error_details}")
        # Escape the message: it can contain file names or other text with
        # HTML metacharacters and is rendered verbatim by the HTML component
        error_html = f"""
        <div style='padding:12px; border:1px solid #d9534f; border-radius:8px;'>
        ❌ <b>Error:</b> {escape(str(e))}
        </div>
        """
        return None, None, error_html
558
-
559
-
560
# Custom CSS - subtle styling
custom_css = """
.predict-btn {
    background-color: #4169E1 !important;
    color: white !important;
}
.clear-btn {
    background-color: #6A89A7 !important;
    color: white !important;
}
"""

# Placeholder shown in the report panel before the first analysis and after Clear
NO_ANALYSIS_HTML = "<i>No analysis yet. Upload a document and click Analyze.</i>"

# Percentages for the static "Model Performance" bars
_dice_pct = MODEL_METRICS['segmentation']['dice'] * 100
_accuracy_pct = MODEL_METRICS['classification']['overall_accuracy'] * 100

# Create Gradio interface
# NOTE(review): container nesting below was reconstructed from a listing that
# had lost its indentation - confirm the layout against the running app
with gr.Blocks(css=custom_css) as demo:

    gr.Markdown(
        """
        # 📄 Document Forgery Detection
        Upload a document image or PDF to detect and classify forgeries using deep learning. The system combines MobileNetV3-UNet for precise localization and LightGBM for classification, identifying Copy-Move, Splicing, and Text Substitution manipulations with detailed confidence scores and bounding boxes. Trained on 140K samples for robust performance.
        """
    )
    gr.Markdown("---")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Upload Document")

            with gr.Tabs():
                with gr.Tab("📤 Upload File"):
                    input_file = gr.File(
                        label="Upload Image, PDF, or Document",
                        file_types=["image", ".pdf", ".doc", ".docx"],
                        type="filepath"
                    )

                with gr.Tab("📷 Webcam"):
                    input_webcam = gr.Image(
                        label="Capture from Webcam",
                        type="filepath",
                        sources=["webcam"]
                    )

            with gr.Row():
                clear_btn = gr.Button("🧹 Clear", elem_classes="clear-btn")
                analyze_btn = gr.Button("🔍 Analyze", elem_classes="predict-btn")

        with gr.Column(scale=1):
            gr.Markdown("### Information")
            # Legend colours follow CLASS_COLORS, which the overlay boxes use:
            # Copy-Move red, Splicing green, Text Substitution blue
            gr.HTML(
                """
                <div style='padding:16px; border:1px solid #ccc; border-radius:8px; background:var(--background-fill-primary);'>
                <p style='margin-top:0;'><b>Supported formats:</b></p>
                <ul style='margin:8px 0; padding-left:20px;'>
                <li>Images: JPG, PNG, BMP, TIFF, WebP</li>
                <li>PDF: First page analyzed</li>
                </ul>

                <p style='margin-bottom:4px;'><b>Forgery types:</b></p>
                <ul style='margin:8px 0; padding-left:20px;'>
                <li style='color:#d9534f;'><b>Copy-Move:</b> <span style='color:inherit;'>Duplicated regions</span></li>
                <li style='color:#5cb85c;'><b>Splicing:</b> <span style='color:inherit;'>Mixed sources</span></li>
                <li style='color:#4169E1;'><b>Text Substitution:</b> <span style='color:inherit;'>Modified text</span></li>
                </ul>
                </div>
                """
            )

        with gr.Column(scale=2):
            gr.Markdown("### Detection Results")
            output_image = gr.Image(label="Detected Forgeries", type="numpy")

    gr.Markdown("---")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Analysis Report")
            output_html = gr.HTML(
                value=NO_ANALYSIS_HTML
            )

        with gr.Column(scale=1):
            gr.Markdown("### Detection Metrics")
            metrics_gauge = gr.Plot(label="Concentric Metrics Gauge")

    gr.Markdown("---")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Model Architecture")
            gr.HTML(
                """
                <div style='padding:12px; border:1px solid #444; border-radius:10px; background:var(--background-fill-primary);'>
                <p style="margin:0 0 0px 0; font-size:1.05em;"><b>Localization:</b> MobileNetV3-Small + UNet</p>
                <p style='margin:0 20px 5px 0; margin-left:0.5cm; font-size:0.9em; opacity:0.85;'>Dice: 62.12% | IoU: 45.06% | Precision: 70.77% | Recall: 55.36%</p>

                <p style="margin:0 0 0 0; font-size:1.05em;"><b>Classification:</b> LightGBM with 526 features</p>
                <p style="margin:0 20px 0 0; margin-left:0.5cm; font-size:0.9em; opacity:0.85;">Train Accuracy: 90.53% | Val Accuracy: 88.97%</p>

                <p style='margin-top:5px; margin-bottom:0; font-size:1.05em;'><b>Training:</b> 140K samples from DocTamper dataset</p>
                </div>
                """
            )

        with gr.Column(scale=1):
            gr.Markdown("### Model Performance")
            gr.HTML(
                f"""
                <div style='padding:12px; border:1px solid #444; border-radius:10px; background:var(--background-fill-primary);'>
                <p style='margin-top:0; margin-bottom:12px;'><b>Trained Model Performance:</b></p>

                <b>Segmentation Dice: {_dice_pct:.2f}%</b>
                <div style='width:100%; background:#333; height:12px; border-radius:6px; margin-bottom:12px;'>
                <div style='width:{_dice_pct:.1f}%; background:#4169E1; height:12px; border-radius:6px;'></div>
                </div>

                <b>Classification Accuracy: {_accuracy_pct:.2f}%</b>
                <div style='width:100%; background:#333; height:12px; border-radius:6px;'>
                <div style='width:{_accuracy_pct:.1f}%; background:#5cb85c; height:12px; border-radius:6px;'></div>
                </div>
                </div>
                """
            )

    # Event handlers
    analyze_btn.click(
        fn=detect_forgery,
        inputs=[input_file, input_webcam],
        outputs=[output_image, metrics_gauge, output_html]
    )

    # Clear resets both inputs and all three outputs (five values, one per component)
    clear_btn.click(
        fn=lambda: (None, None, None, None, NO_ANALYSIS_HTML),
        inputs=None,
        outputs=[input_file, input_webcam, output_image, metrics_gauge, output_html]
    )


if __name__ == "__main__":
    demo.launch()
 
 
1
  """
2
+ Mask refinement and region extraction
3
+ Implements Critical Fix #3: Adaptive Mask Refinement Thresholds
 
4
  """
5
 
 
 
6
  import cv2
7
  import numpy as np
8
+ from typing import List, Tuple, Dict, Optional
9
+ from scipy import ndimage
10
+ from skimage.measure import label, regionprops
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
class MaskRefiner:
    """
    Mask refinement with adaptive thresholds
    Implements Critical Fix #3: Dataset-specific minimum region areas
    """

    def __init__(self, config, dataset_name: str = 'default'):
        """
        Initialize mask refiner

        Args:
            config: Configuration object
            dataset_name: Dataset name for adaptive thresholds
        """
        self.config = config
        self.dataset_name = dataset_name

        # Get mask refinement parameters
        self.threshold = config.get('mask_refinement.threshold', 0.5)
        self.closing_kernel = config.get('mask_refinement.morphology.closing_kernel', 5)
        self.opening_kernel = config.get('mask_refinement.morphology.opening_kernel', 3)

        # Critical Fix #3: Adaptive thresholds per dataset
        # (used below as a fraction of the total image area)
        self.min_region_area = config.get_min_region_area(dataset_name)

        print(f"MaskRefiner initialized for {dataset_name}")
        print(f"Min region area: {self.min_region_area * 100:.2f}%")

    def refine(self,
               probability_map: np.ndarray,
               original_size: Optional[Tuple[int, int]] = None) -> np.ndarray:
        """
        Refine probability map to binary mask

        Args:
            probability_map: Forgery probability map (H, W), values [0, 1]
            original_size: Optional (H, W) to resize mask back to original

        Returns:
            Refined binary mask (H, W)
        """
        # Threshold to binary
        binary_mask = (probability_map > self.threshold).astype(np.uint8)

        # Morphological closing (fill broken strokes)
        closing_kernel = cv2.getStructuringElement(
            cv2.MORPH_RECT,
            (self.closing_kernel, self.closing_kernel)
        )
        binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, closing_kernel)

        # Morphological opening (remove isolated noise)
        opening_kernel = cv2.getStructuringElement(
            cv2.MORPH_RECT,
            (self.opening_kernel, self.opening_kernel)
        )
        binary_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_OPEN, opening_kernel)

        # Critical Fix #3: Remove small regions with adaptive threshold
        binary_mask = self._remove_small_regions(binary_mask)

        # Resize to original size if provided; a nearest-neighbour resize to
        # the current size would be an identity copy, so it is skipped
        if original_size is not None and tuple(original_size) != binary_mask.shape:
            binary_mask = cv2.resize(
                binary_mask,
                (original_size[1], original_size[0]),  # cv2 uses (W, H)
                interpolation=cv2.INTER_NEAREST
            )

        return binary_mask

    def _remove_small_regions(self, mask: np.ndarray) -> np.ndarray:
        """
        Remove regions smaller than minimum area threshold

        Args:
            mask: Binary mask (H, W)

        Returns:
            Filtered mask with the same shape and dtype as `mask`, holding 1
            in every kept region and 0 elsewhere
        """
        # Calculate minimum pixel count
        image_area = mask.shape[0] * mask.shape[1]
        min_pixels = int(image_area * self.min_region_area)

        # Label connected components (scipy default connectivity)
        labeled_mask, num_features = ndimage.label(mask)

        # Area of every label in one pass instead of one full-image scan per region
        areas = np.bincount(labeled_mask.ravel(), minlength=num_features + 1)

        # Keep only large enough regions; label 0 is background and never kept
        keep = areas >= min_pixels
        keep[0] = False

        return keep[labeled_mask].astype(mask.dtype)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
 
114
class RegionExtractor:
    """
    Extract individual regions from binary mask
    Implements Critical Fix #4: Region Confidence Aggregation
    """

    def __init__(self, config, dataset_name: str = 'default'):
        """
        Initialize region extractor

        Args:
            config: Configuration object
            dataset_name: Dataset name
        """
        self.config = config
        self.dataset_name = dataset_name
        self.min_region_area = config.get_min_region_area(dataset_name)

    def extract(self,
                binary_mask: np.ndarray,
                probability_map: np.ndarray,
                original_image: np.ndarray) -> List[Dict]:
        """
        Extract regions from binary mask

        Args:
            binary_mask: Refined binary mask (H, W)
            probability_map: Original probability map (H, W)
            original_image: Original image (H, W, 3)

        Returns:
            List of region dictionaries with bounding box, mask, image, confidence
        """
        regions = []

        print("[REGION_EXTRACT] Input shapes:")
        print(f" - binary_mask: {binary_mask.shape}")
        print(f" - probability_map: {probability_map.shape}")
        print(f" - original_image: {original_image.shape}")

        # Safety check: Ensure probability_map and binary_mask have same dimensions.
        # Every region mask below has binary_mask's shape, so aligning once
        # here is enough and no per-region check is needed
        if probability_map.shape != binary_mask.shape:
            print(f"[REGION_EXTRACT] WARNING: Shape mismatch! Resizing probability_map from {probability_map.shape} to {binary_mask.shape}")
            probability_map = cv2.resize(
                probability_map,
                (binary_mask.shape[1], binary_mask.shape[0]),  # cv2 uses (W, H)
                interpolation=cv2.INTER_LINEAR
            )
            print(f"[REGION_EXTRACT] After resize: probability_map shape = {probability_map.shape}")

        # Connected component analysis (8-connectivity)
        labeled_mask = label(binary_mask, connectivity=2)
        props = regionprops(labeled_mask)

        for region_id, prop in enumerate(props, start=1):
            # Bounding box
            y_min, x_min, y_max, x_max = prop.bbox

            # Region mask, keyed by the component's own label rather than by
            # its position in the regionprops list
            region_mask = (labeled_mask == prop.label).astype(np.uint8)

            # Cropped region image
            region_image = original_image[y_min:y_max, x_min:x_max].copy()
            region_mask_cropped = region_mask[y_min:y_max, x_min:x_max]

            # Critical Fix #4: Region-level confidence aggregation.
            # The region lies entirely inside its bounding box, so indexing the
            # cropped window selects the same pixels as the full-image mask
            window_probs = probability_map[y_min:y_max, x_min:x_max]
            region_probs = window_probs[region_mask_cropped > 0]
            region_confidence = float(np.mean(region_probs)) if len(region_probs) > 0 else 0.0

            regions.append({
                'region_id': region_id,
                # [x, y, width, height]
                'bounding_box': [int(x_min), int(y_min),
                                 int(x_max - x_min), int(y_max - y_min)],
                'area': prop.area,
                # (x, y)
                'centroid': (int(prop.centroid[1]), int(prop.centroid[0])),
                'region_mask': region_mask,
                'region_mask_cropped': region_mask_cropped,
                'region_image': region_image,
                'confidence': region_confidence,
                'mask_probability_mean': region_confidence
            })

        return regions

    def extract_for_casia(self,
                          binary_mask: np.ndarray,
                          probability_map: np.ndarray,
                          original_image: np.ndarray) -> List[Dict]:
        """
        Critical Fix #6: CASIA handling - treat entire image as one region

        Args:
            binary_mask: Binary mask (may be empty for authentic images)
            probability_map: Probability map
            original_image: Original image

        Returns:
            Single region representing entire image
        """
        h, w = original_image.shape[:2]

        # Create single region covering entire image
        region_mask = np.ones((h, w), dtype=np.uint8)

        # Overall confidence from probability map; an empty map yields 0.0
        # (same fallback as extract) instead of NaN
        if probability_map.size > 0:
            overall_confidence = float(np.mean(probability_map))
        else:
            overall_confidence = 0.0

        return [{
            'region_id': 1,
            'bounding_box': [0, 0, w, h],
            'area': h * w,
            'centroid': (w // 2, h // 2),
            'region_mask': region_mask,
            'region_mask_cropped': region_mask,
            # Not copied: this is the caller's array
            'region_image': original_image,
            'confidence': overall_confidence,
            'mask_probability_mean': overall_confidence
        }]
 
 
 
 
 
 
244
 
245
+
246
def get_mask_refiner(config, dataset_name: str = 'default') -> MaskRefiner:
    """Build a MaskRefiner for the given configuration and dataset name"""
    refiner = MaskRefiner(config, dataset_name)
    return refiner
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
 
250
 
251
def get_region_extractor(config, dataset_name: str = 'default') -> RegionExtractor:
    """Build a RegionExtractor for the given configuration and dataset name"""
    extractor = RegionExtractor(config, dataset_name)
    return extractor