JKrishnanandhaa committed on
Commit
f736271
·
verified ·
1 Parent(s): 3870592

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +505 -143
app.py CHANGED
@@ -1,6 +1,7 @@
1
  """
2
- Document Forgery Detection Professional Gradio Dashboard
3
- Hugging Face Spaces Deployment
 
4
  """
5
 
6
  import gradio as gr
@@ -8,14 +9,13 @@ import torch
8
  import cv2
9
  import numpy as np
10
  from PIL import Image
11
- import plotly.graph_objects as go
12
  from pathlib import Path
13
  import sys
14
- import json
 
15
 
16
- # -------------------------------------------------
17
- # PATH SETUP
18
- # -------------------------------------------------
19
  sys.path.insert(0, str(Path(__file__).parent))
20
 
21
  from src.models import get_model
@@ -26,181 +26,543 @@ from src.features.region_extraction import get_mask_refiner, get_region_extracto
26
  from src.features.feature_extraction import get_feature_extractor
27
  from src.training.classifier import ForgeryClassifier
28
 
# Display name for each classifier label index.
CLASS_NAMES = dict(enumerate(("Copy-Move", "Splicing", "Generation")))

# Colour triple used when drawing each class, keyed by the same label index.
CLASS_COLORS = dict(
    zip(
        range(3),
        (
            (255, 0, 0),
            (0, 255, 0),
            (0, 0, 255),
        ),
    )
)
38
 
39
- # -------------------------------------------------
40
- # FORGERY DETECTOR (UNCHANGED CORE LOGIC)
41
- # -------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
class ForgeryDetector:
    """Inference pipeline: localise suspicious regions, then label each one."""

    def __init__(self):
        """Load the config, the segmentation model, the classifier and helpers."""
        print("Loading models...")

        cfg = get_config("config.yaml")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.config = cfg
        self.device = device

        # Segmentation network, restored from the checkpoint and put in eval mode.
        segmenter = get_model(cfg).to(device)
        state = torch.load("models/best_doctamper.pth", map_location=device)
        segmenter.load_state_dict(state["model_state_dict"])
        segmenter.eval()
        self.model = segmenter

        # Region classifier.
        region_classifier = ForgeryClassifier(cfg)
        region_classifier.load("models/classifier")
        self.classifier = region_classifier

        # Pre/post-processing components, all built from the same config.
        self.preprocessor = DocumentPreprocessor(cfg, "doctamper")
        self.augmentation = DatasetAwareAugmentation(cfg, "doctamper", is_training=False)
        self.mask_refiner = get_mask_refiner(cfg)
        self.region_extractor = get_region_extractor(cfg)
        self.feature_extractor = get_feature_extractor(cfg, is_text_document=True)

        print("✓ Models loaded")

    def detect(self, image):
        """Run the pipeline on a PIL image or array.

        Returns:
            A pair ``(overlay, summary)`` where ``overlay`` is the input with
            boxes and labels drawn on it and ``summary`` is a dict with the
            keys ``"num_detections"`` and ``"detections"``.
        """
        if isinstance(image, Image.Image):
            image = np.array(image)

        # Bring grayscale and 4-channel input to three channels.
        if image.ndim == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        elif image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)

        original = image.copy()

        preprocessed, _ = self.preprocessor(image, None)
        batch = self.augmentation(preprocessed, None)["image"].unsqueeze(0)
        image_tensor = batch.to(self.device)

        with torch.no_grad():
            logits, decoder_features = self.model(image_tensor)
            prob_map = torch.sigmoid(logits).cpu().numpy()[0, 0]

        binary = (prob_map > 0.5).astype(np.uint8)
        refined = self.mask_refiner.refine(binary, original_size=original.shape[:2])
        regions = self.region_extractor.extract(refined, prob_map, original)

        results = []
        for region in regions:
            features = self.feature_extractor.extract(
                preprocessed,
                region["region_mask"],
                [feat.cpu() for feat in decoder_features],
            )

            if features.ndim == 1:
                features = features.reshape(1, -1)

            # Zero-pad or truncate to the 526 columns passed to the classifier.
            width = features.shape[1]
            if width != 526:
                shortfall = max(0, 526 - width)
                features = np.pad(features, ((0, 0), (0, shortfall)))
                features = features[:, :526]

            pred, conf = self.classifier.predict(features)
            if not conf[0] > 0.6:
                continue

            results.append({
                "bounding_box": region["bounding_box"],
                "forgery_type": CLASS_NAMES[int(pred[0])],
                "confidence": float(conf[0]),
            })

        overlay = self._draw_overlay(original, results)
        summary = {
            "num_detections": len(results),
            "detections": results,
        }
        return overlay, summary

    def _draw_overlay(self, image, results):
        """Return a copy of ``image`` with one box and caption per detection."""
        canvas = image.copy()
        for det in results:
            x, y, w, h = det["bounding_box"]
            # Map the display name back to its label index to find the colour.
            class_id = [
                key for key, name in CLASS_NAMES.items()
                if name == det["forgery_type"]
            ][0]
            color = CLASS_COLORS[class_id]

            cv2.rectangle(canvas, (x, y), (x + w, y + h), color, 2)
            label = f"{det['forgery_type']} ({det['confidence']*100:.1f}%)"
            cv2.putText(
                canvas, label, (x, y - 6),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2,
            )
        return canvas
128
-
129
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  detector = ForgeryDetector()
131
 
def gauge(value, title):
    """Build a Plotly gauge indicator on a 0-100 axis showing ``value``."""
    indicator = go.Indicator(
        mode="gauge+number",
        value=value,
        title=dict(text=title),
        gauge=dict(
            axis=dict(range=[0, 100]),
            bar=dict(color="#2563eb"),
        ),
    )
    figure = go.Figure(indicator)
    figure.update_layout(height=240, margin={"t": 40, "b": 20})
    return figure
144
-
def run_detection(file):
    """Gradio callback: analyse the uploaded file and build the output widgets.

    Returns the overlay image, the raw result dict and three gauges. The first
    two gauges show fixed figures (75 and 92); the third shows the mean
    confidence of this run's detections as a percentage.
    """
    image = Image.open(file.name)
    overlay, result = detector.detect(image)

    detections = result["detections"]
    # Divide by at least 1 so an empty result yields 0 instead of an error.
    divisor = max(1, result["num_detections"])
    avg_conf = (sum(det["confidence"] for det in detections) / divisor) * 100

    dice_gauge = gauge(75, "Localization Dice (%)")
    accuracy_gauge = gauge(92, "Classifier Accuracy (%)")
    confidence_gauge = gauge(avg_conf, "Avg Detection Confidence (%)")
    return (overlay, result, dice_gauge, accuracy_gauge, confidence_gauge)
163
-
# Dashboard layout: upload row, result image, then tabs for metrics,
# raw details and team credits.
TEAM_MARKDOWN = """
**Document Forgery Detection Project**

Krishnanandhaa Model & Training
Teammate 1 — Feature Engineering
Teammate 2 — Evaluation
Teammate 3 — Deployment

*Collaborators are added via Hugging Face Space settings.*
"""

with gr.Blocks(theme=gr.themes.Soft(), title="Document Forgery Detection") as demo:

    gr.Markdown("# 📄 Document Forgery Detection System")

    with gr.Row():
        file_input = gr.File(label="Upload Document (Image/PDF)")
        detect_btn = gr.Button("Run Detection", variant="primary")

    output_img = gr.Image(label="Forgery Localization Result", type="numpy")

    with gr.Tabs():
        with gr.Tab("📊 Metrics"):
            with gr.Row():
                dice_plot = gr.Plot()
                acc_plot = gr.Plot()
                conf_plot = gr.Plot()

        with gr.Tab("🧾 Details"):
            json_out = gr.JSON()

        with gr.Tab("👥 Team"):
            gr.Markdown(TEAM_MARKDOWN)

    # Output order matches the tuple returned by run_detection.
    detection_outputs = [output_img, json_out, dice_plot, acc_plot, conf_plot]
    detect_btn.click(
        run_detection,
        inputs=file_input,
        outputs=detection_outputs,
    )

if __name__ == "__main__":
    demo.launch()
 
1
  """
2
+ Document Forgery Detection - Gradio Interface for Hugging Face Spaces
3
+
4
+ This app provides a web interface for detecting and classifying document forgeries.
5
  """
6
 
7
  import gradio as gr
 
9
  import cv2
10
  import numpy as np
11
  from PIL import Image
12
+ import json
13
  from pathlib import Path
14
  import sys
15
+ from typing import Dict, List, Tuple
16
+ import plotly.graph_objects as go
17
 
18
+ # Add src to path
 
 
19
  sys.path.insert(0, str(Path(__file__).parent))
20
 
21
  from src.models import get_model
 
26
  from src.features.feature_extraction import get_feature_extractor
27
  from src.training.classifier import ForgeryClassifier
28
 
# Display name for each classifier label index.
CLASS_NAMES = dict(enumerate(('Copy-Move', 'Splicing', 'Text Substitution')))

# RGB colour for each label index, used for overlay boxes and report accents.
CLASS_COLORS = dict(enumerate((
    (217, 83, 79),    # #d9534f - Muted red
    (92, 184, 92),    # #5cb85c - Muted green
    (65, 105, 225),   # #4169E1 - Royal blue
)))

# Fixed performance figures (fractions in [0, 1]) shown in the UI.
MODEL_METRICS = dict(
    segmentation=dict(
        dice=0.6212,
        iou=0.4506,
        precision=0.7077,
        recall=0.5536,
    ),
    classification=dict(
        overall_accuracy=0.8897,
        per_class=dict(
            copy_move=0.92,
            splicing=0.85,
            generation=0.90,
        ),
    ),
)
54
+
55
+
def create_gauge_chart(value: float, title: str, max_value: float = 1.0) -> go.Figure:
    """Create a subtle radial gauge chart on a 0-100 % scale.

    Args:
        value: Metric to display, on a scale from 0 to ``max_value``.
        title: Caption drawn above the gauge.
        max_value: Value that corresponds to 100 %. With the default of 1.0
            the gauge shows ``value * 100``.

    Returns:
        A Plotly figure holding one gauge indicator.

    Raises:
        ValueError: If ``max_value`` is not positive.
    """
    if max_value <= 0:
        raise ValueError("max_value must be positive")

    # Honour max_value so callers can pass metrics that are not fractions.
    percent = value / max_value * 100

    fig = go.Figure(go.Indicator(
        mode="gauge+number",
        value=percent,
        domain={'x': [0, 1], 'y': [0, 1]},
        title={'text': title, 'font': {'size': 14}},
        number={'suffix': '%', 'font': {'size': 24}},
        gauge={
            'axis': {'range': [0, 100], 'tickwidth': 1},
            'bar': {'color': '#4169E1', 'thickness': 0.7},
            'bgcolor': 'rgba(0,0,0,0)',
            'borderwidth': 0,
            # Faint red / amber / green bands behind the bar.
            'steps': [
                {'range': [0, 50], 'color': 'rgba(217, 83, 79, 0.1)'},
                {'range': [50, 75], 'color': 'rgba(240, 173, 78, 0.1)'},
                {'range': [75, 100], 'color': 'rgba(92, 184, 92, 0.1)'}
            ]
        }
    ))

    fig.update_layout(
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        height=200,
        margin=dict(l=20, r=20, t=40, b=20)
    )

    return fig
85
+
86
+
def create_detection_metrics_gauge(avg_confidence: float, iou: float, precision: float, recall: float, num_detections: int) -> go.Figure:
    """Draw four metrics as concentric rings in a single polar bar chart.

    Each ring is a faint full-circle track with a coloured arc on top whose
    sweep is proportional to the metric (100 % = 360 degrees). Inputs are
    fractions; the confidence ring is forced to 0 when there are no detections.
    """
    confidence_pct = avg_confidence * 100 if num_detections > 0 else 0

    # (legend name, percentage, colour, radius where the ring starts),
    # listed from the outermost ring inwards.
    rings = (
        ('Confidence', confidence_pct, '#4169E1', 80),
        ('Precision', precision * 100, '#5cb85c', 60),
        ('Recall', recall * 100, '#f0ad4e', 40),
        ('IoU', iou * 100, '#d9534f', 20),
    )

    fig = go.Figure()

    for ring_name, pct, ring_color, ring_base in rings:
        # Background track: one bar centred at 180 degrees, 360 degrees wide.
        track = go.Barpolar(
            r=[15],
            theta=[180],
            width=[360],
            base=ring_base,
            marker_color='rgba(128,128,128,0.1)',
            hoverinfo='none',
            showlegend=False,
        )
        fig.add_trace(track)

        # Coloured arc: starts at 0 degrees and sweeps 3.6 degrees per percent.
        sweep = pct * 3.6
        arc = go.Barpolar(
            r=[15],
            theta=[sweep / 2],
            width=[sweep],
            base=ring_base,
            name=f"{ring_name}: {pct:.1f}%",
            marker_color=ring_color,
            marker_line_width=0,
            hoverinfo='name',
        )
        fig.add_trace(arc)

    angular_axis = dict(
        rotation=90,            # Start at 12 o'clock
        direction='clockwise',  # Go clockwise
        gridcolor='rgba(128,128,128,0.2)',
        tickmode='array',
        tickvals=[0, 90, 180, 270],
        ticktext=['0%', '25%', '50%', '75%'],
        showticklabels=True,
        tickfont=dict(size=12, color='#888'),
    )
    polar_layout = dict(
        hole=0.1,
        radialaxis=dict(visible=False, range=[0, 100]),
        angularaxis=angular_axis,
        bgcolor='rgba(0,0,0,0)',
    )
    legend_layout = dict(
        orientation="v",
        yanchor="middle",
        y=0.5,
        xanchor="left",
        x=1.1,
        font=dict(size=14, color='white'),
        itemwidth=30,
    )

    fig.update_layout(
        polar=polar_layout,
        showlegend=True,
        legend=legend_layout,
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
        height=450,
        margin=dict(l=60, r=180, t=40, b=40),
    )

    return fig
159
+
160
+
class ForgeryDetector:
    """Main forgery detection pipeline.

    The models are loaded once in ``__init__``. ``detect`` then loads the
    input, normalises it to RGB, runs the segmentation model, refines the mask,
    extracts candidate regions, classifies each region and builds the overlay
    image, the metrics figure and the HTML report.
    """

    # Number of feature columns passed to the classifier; shorter vectors are
    # zero-padded and longer ones truncated.
    EXPECTED_FEATURES = 526
    # A region is reported only when its classifier confidence exceeds this.
    MIN_CONFIDENCE = 0.6
    # Probability cut-off for the high-confidence mask used by the display metrics.
    HIGH_CONF_THRESHOLD = 0.7
    # Zoom factor applied on both axes when rasterising a PDF page.
    PDF_ZOOM = 2

    def __init__(self):
        """Load the config, segmentation checkpoint, classifier and helpers."""
        print("Loading models...")

        # Load config
        self.config = get_config('config.yaml')
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Load segmentation model
        self.model = get_model(self.config).to(self.device)
        checkpoint = torch.load('models/best_doctamper.pth', map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()

        # Load classifier
        self.classifier = ForgeryClassifier(self.config)
        self.classifier.load('models/classifier')

        # Initialize components
        self.preprocessor = DocumentPreprocessor(self.config, 'doctamper')
        self.augmentation = DatasetAwareAugmentation(self.config, 'doctamper', is_training=False)
        self.mask_refiner = get_mask_refiner(self.config)
        self.region_extractor = get_region_extractor(self.config)
        self.feature_extractor = get_feature_extractor(self.config, is_text_document=True)

        print("✓ Models loaded successfully!")

    def detect(self, image):
        """Detect forgeries in a document image or PDF.

        Args:
            image: A file path (image, or PDF whose first page is analysed),
                a PIL image, or a numpy array.

        Returns:
            A tuple ``(overlay, metrics_gauge, results_html)``: the input with
            detections drawn on it, the Plotly metrics figure, and the HTML
            report.
        """
        if isinstance(image, str):
            image = self._load_image(image)
        image = self._to_rgb(image)

        original_image = image.copy()

        # Preprocess and augment into the model's input tensor
        preprocessed, _ = self.preprocessor(image, None)
        augmented = self.augmentation(preprocessed, None)
        image_tensor = augmented['image'].unsqueeze(0).to(self.device)

        # Run localization
        with torch.no_grad():
            logits, decoder_features = self.model(image_tensor)
            prob_map = torch.sigmoid(logits).cpu().numpy()[0, 0]

        # Resize probability map to match original image size to avoid index mismatch errors
        height, width = original_image.shape[:2]
        prob_map_resized = cv2.resize(
            prob_map,
            (width, height),
            interpolation=cv2.INTER_LINEAR
        )

        # NOTE(review): the refiner receives the probability map, not a
        # thresholded mask - confirm refine() thresholds internally.
        refined_mask = self.mask_refiner.refine(prob_map_resized, original_size=original_image.shape[:2])

        # Extract and classify regions
        regions = self.region_extractor.extract(refined_mask, prob_map_resized, original_image)
        results = self._classify_regions(regions, preprocessed, decoder_features)

        # Create visualization
        overlay = self._create_overlay(original_image, results)

        num_detections = len(results)
        if num_detections > 0:
            avg_confidence = sum(r['confidence'] for r in results) / num_detections
            iou, precision, recall = self._mask_agreement(
                refined_mask, prob_map_resized, self.HIGH_CONF_THRESHOLD
            )
        else:
            # No detections - use zeros
            avg_confidence = 0
            iou = precision = recall = 0

        metrics_gauge = create_detection_metrics_gauge(avg_confidence, iou, precision, recall, num_detections)
        results_html = self._create_html_report(results)

        return overlay, metrics_gauge, results_html

    @staticmethod
    def _load_image(path):
        """Read an image file, or the first page of a PDF, into an RGB array."""
        if path.lower().endswith('.pdf'):
            import fitz  # PyMuPDF; imported here so it is only needed for PDF input

            pdf_document = fitz.open(path)
            try:
                if len(pdf_document) == 0:
                    raise ValueError("The PDF contains no pages.")
                page = pdf_document[0]
                zoom = ForgeryDetector.PDF_ZOOM
                pix = page.get_pixmap(matrix=fitz.Matrix(zoom, zoom))
                # Copy because frombuffer over immutable bytes is read-only.
                return np.frombuffer(pix.samples, dtype=np.uint8).reshape(
                    pix.height, pix.width, pix.n
                ).copy()
            finally:
                # Close the document even when rendering fails.
                pdf_document.close()

        # convert('RGB') normalises palette, LA, CMYK and similar modes; the
        # context manager closes the file handle.
        with Image.open(path) as img:
            return np.array(img.convert('RGB'))

    @staticmethod
    def _to_rgb(image):
        """Normalise a PIL image or a 1-, 3- or 4-channel array to three channels."""
        if isinstance(image, Image.Image):
            image = np.array(image.convert('RGB'))

        # Treat an (H, W, 1) array as plain grayscale.
        if image.ndim == 3 and image.shape[2] == 1:
            image = image[:, :, 0]

        if image.ndim == 2:
            return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        if image.shape[2] == 4:
            return cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
        return image

    @staticmethod
    def _fit_feature_width(features, width):
        """Zero-pad or truncate a 2-D feature array to exactly ``width`` columns."""
        current = features.shape[1]
        if current < width:
            padding = np.zeros((features.shape[0], width - current))
            return np.hstack([features, padding])
        return features[:, :width]

    @staticmethod
    def _mask_agreement(refined_mask, prob_map, threshold):
        """Score how well the refined mask agrees with the high-probability mask.

        No ground truth is involved: the "reference" is ``prob_map > threshold``.

        Returns:
            ``(iou, precision, recall)`` as floats; a ratio whose denominator
            is zero is reported as 0.0.
        """
        predicted = refined_mask > 0
        high_conf = prob_map > threshold

        intersection = np.sum(predicted & high_conf)
        union = np.sum(predicted | high_conf)
        predicted_positive = np.sum(predicted)
        high_conf_positive = np.sum(high_conf)

        iou = float(intersection / union) if union > 0 else 0.0
        precision = float(intersection / predicted_positive) if predicted_positive > 0 else 0.0
        recall = float(intersection / high_conf_positive) if high_conf_positive > 0 else 0.0
        return iou, precision, recall

    @staticmethod
    def _class_color(forgery_type):
        """Return the RGB colour registered for a class display name."""
        forgery_id = [k for k, v in CLASS_NAMES.items() if v == forgery_type][0]
        return CLASS_COLORS[forgery_id]

    def _classify_regions(self, regions, preprocessed, decoder_features):
        """Classify each region and keep those above ``MIN_CONFIDENCE``."""
        results = []
        cpu_features = None
        for region in regions:
            # The decoder features are the same for every region; move them to
            # the CPU once, and only if there is a region to classify.
            if cpu_features is None:
                cpu_features = [f.cpu() for f in decoder_features]

            features = self.feature_extractor.extract(
                preprocessed,
                region['region_mask'],
                cpu_features
            )

            # Reshape features to 2D array
            if features.ndim == 1:
                features = features.reshape(1, -1)

            # Pad/truncate features to match classifier
            features = self._fit_feature_width(features, self.EXPECTED_FEATURES)

            predictions, confidences = self.classifier.predict(features)
            forgery_type = int(predictions[0])
            confidence = float(confidences[0])

            if confidence > self.MIN_CONFIDENCE:
                results.append({
                    'region_id': region['region_id'],
                    'bounding_box': region['bounding_box'],
                    'forgery_type': CLASS_NAMES[forgery_type],
                    'confidence': confidence
                })
        return results

    def _create_overlay(self, image, results):
        """Return a copy of ``image`` with a box and caption for each detection."""
        overlay = image.copy()

        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.5
        thickness = 1

        for result in results:
            x, y, w, h = (int(v) for v in result['bounding_box'])
            forgery_type = result['forgery_type']
            confidence = result['confidence']
            color = self._class_color(forgery_type)

            # Draw rectangle
            cv2.rectangle(overlay, (x, y), (x + w, y + h), color, 2)

            # Draw label
            label = f"{forgery_type}: {confidence:.1%}"
            (label_w, label_h), _ = cv2.getTextSize(label, font, font_scale, thickness)

            # Put the caption above the box; if that would fall off the top of
            # the image, put it just inside the box's top edge instead.
            label_bottom = y if y - label_h - 8 >= 0 else y + label_h + 8

            cv2.rectangle(overlay, (x, label_bottom - label_h - 8), (x + label_w + 4, label_bottom), color, -1)
            cv2.putText(overlay, label, (x + 2, label_bottom - 4), font, font_scale, (255, 255, 255), thickness)

        return overlay

    def _create_html_report(self, results):
        """Create HTML report with detection results.

        Returns a green "no forgery" card when ``results`` is empty, otherwise
        a red card with a summary and one entry per detection.
        """
        num_detections = len(results)

        if num_detections == 0:
            return """
            <div style='padding:12px; border:1px solid #5cb85c; border-radius:8px;'>
                ✓ <b>No forgery detected.</b><br>
                The document appears to be authentic.
            </div>
            """

        avg_confidence = sum(r['confidence'] for r in results) / num_detections

        parts = [f"""
        <div style='padding:12px; border:1px solid #d9534f; border-radius:8px;'>
            <b>⚠️ Forgery Detected</b><br><br>

            <b>Summary:</b><br>
            Regions detected: {num_detections}<br>
            • Average confidence: {avg_confidence*100:.1f}%<br><br>

            <b>Detections:</b><br>
        """]

        for i, result in enumerate(results, 1):
            forgery_type = result['forgery_type']
            confidence = result['confidence']
            bbox = result['bounding_box']

            # Same colour as the overlay box, as a CSS hex string.
            red, green, blue = self._class_color(forgery_type)
            color_hex = f"#{red:02x}{green:02x}{blue:02x}"

            parts.append(f"""
            <div style='margin:8px 0; padding:8px; border-left:3px solid {color_hex}; background:rgba(0,0,0,0.02);'>
                <b>Region {i}:</b> {forgery_type} ({confidence*100:.1f}%)<br>
                <small>Location: ({bbox[0]}, {bbox[1]}) | Size: {bbox[2]}×{bbox[3]}px</small>
            </div>
            """)

        parts.append("""
        </div>
        """)

        return "".join(parts)
409
+
410
+
411
+ # Initialize detector
412
  detector = ForgeryDetector()
413
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
414
 
def detect_forgery(file):
    """Gradio interface function - handles image and PDF uploads.

    Args:
        file: Path of the uploaded file, or ``None`` when nothing was uploaded.

    Returns:
        A tuple ``(overlay, metrics_gauge, results_html)``. When no file was
        uploaded or the analysis fails, the first two items are ``None`` and
        the third is an HTML message describing the problem.
    """
    try:
        if file is None:
            empty_html = "<div style='padding:12px; border:1px solid #d9534f; border-radius:8px;'>❌ <b>No file uploaded.</b></div>"
            return None, None, empty_html

        # Detect forgeries
        overlay, metrics_gauge, results_html = detector.detect(file)

        return overlay, metrics_gauge, results_html

    except Exception as e:
        # Top-level UI boundary: log the full traceback and show a short
        # message in the report instead of letting the exception propagate.
        import traceback
        from html import escape

        error_details = traceback.format_exc()
        print(f"Error: {error_details}")
        # The message can contain file names or other markup-significant
        # characters, so escape it before embedding it in HTML.
        error_html = f"""
        <div style='padding:12px; border:1px solid #d9534f; border-radius:8px;'>
            ❌ <b>Error:</b> {escape(str(e))}
        </div>
        """
        return None, None, error_html
437
+
438
+
# Custom CSS - subtle styling
custom_css = """
.predict-btn {
    background-color: #4169E1 !important;
    color: white !important;
}
.clear-btn {
    background-color: #6A89A7 !important;
    color: white !important;
}
"""

# Create Gradio interface.
# Layout: header, then a row with upload / information / result image, a row
# with the report and the metrics figure, and a row with static model details.
with gr.Blocks(css=custom_css) as demo:

    gr.Markdown(
        """
        # 📄 Document Forgery Detection
        Upload a document image or PDF to detect and classify forgeries using deep learning. The system combines MobileNetV3-UNet for precise localization and LightGBM for classification, identifying Copy-Move, Splicing, and Text Substitution manipulations with detailed confidence scores and bounding boxes. Trained on 140K samples for robust performance.
        """
    )
    gr.Markdown("---")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Upload Document")

            input_file = gr.File(
                label="📤 Upload Image or PDF",
                file_types=["image", ".pdf"],
                type="filepath"
            )

            with gr.Row():
                clear_btn = gr.Button("🧹 Clear", elem_classes="clear-btn")
                analyze_btn = gr.Button("🔍 Analyze", elem_classes="predict-btn")

        with gr.Column(scale=1):
            gr.Markdown("### Information")
            # The legend colours must mirror CLASS_COLORS, which colours the
            # overlay boxes: Copy-Move #d9534f, Splicing #5cb85c,
            # Text Substitution #4169E1.
            gr.HTML(
                """
                <div style='padding:16px; border:1px solid #ccc; border-radius:8px; background:var(--background-fill-primary);'>
                    <p style='margin-top:0;'><b>Supported formats:</b></p>
                    <ul style='margin:8px 0; padding-left:20px;'>
                        <li>Images: JPG, PNG, BMP, TIFF, WebP</li>
                        <li>PDF: First page analyzed</li>
                    </ul>

                    <p style='margin-bottom:4px;'><b>Forgery types:</b></p>
                    <ul style='margin:8px 0; padding-left:20px;'>
                        <li style='color:#d9534f;'><b>Copy-Move:</b> <span style='color:inherit;'>Duplicated regions</span></li>
                        <li style='color:#5cb85c;'><b>Splicing:</b> <span style='color:inherit;'>Mixed sources</span></li>
                        <li style='color:#4169E1;'><b>Text Substitution:</b> <span style='color:inherit;'>Modified text</span></li>
                    </ul>
                </div>
                """
            )

        with gr.Column(scale=2):
            gr.Markdown("### Detection Results")
            output_image = gr.Image(label="Detected Forgeries", type="numpy")

    gr.Markdown("---")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Analysis Report")
            output_html = gr.HTML(
                value="<i>No analysis yet. Upload a document and click Analyze.</i>"
            )

        with gr.Column(scale=1):
            gr.Markdown("### Detection Metrics")
            metrics_gauge = gr.Plot(label="Concentric Metrics Gauge")

    gr.Markdown("---")

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Model Architecture")
            gr.HTML(
                """
                <div style='padding:12px; border:1px solid #444; border-radius:10px; background:var(--background-fill-primary);'>
                    <p style="margin:0 0 0px 0; font-size:1.05em;"><b>Localization:</b> MobileNetV3-Small + UNet</p>
                    <p style='margin:0 20px 5px 0; margin-left:0.5cm; font-size:0.9em; opacity:0.85;'>Dice: 62.12% | IoU: 45.06% | Precision: 70.77% | Recall: 55.36%</p>

                    <p style="margin:0 0 0 0; font-size:1.05em;"><b>Classification:</b> LightGBM with 526 features</p>
                    <p style="margin:0 20px 0 0; margin-left:0.5cm; font-size:0.9em; opacity:0.85;">Train Accuracy: 90.53% | Val Accuracy: 88.97%</p>

                    <p style='margin-top:5px; margin-bottom:0; font-size:1.05em;'><b>Training:</b> 140K samples from DocTamper dataset</p>
                </div>
                """
            )

        with gr.Column(scale=1):
            gr.Markdown("### Model Performance")
            # The bar widths are filled in from MODEL_METRICS when the UI is built.
            gr.HTML(
                f"""
                <div style='padding:12px; border:1px solid #444; border-radius:10px; background:var(--background-fill-primary);'>
                    <p style='margin-top:0; margin-bottom:12px;'><b>Trained Model Performance:</b></p>

                    <b>Segmentation Dice: {MODEL_METRICS['segmentation']['dice']*100:.2f}%</b>
                    <div style='width:100%; background:#333; height:12px; border-radius:6px; margin-bottom:12px;'>
                        <div style='width:{MODEL_METRICS['segmentation']['dice']*100:.1f}%; background:#4169E1; height:12px; border-radius:6px;'></div>
                    </div>

                    <b>Classification Accuracy: {MODEL_METRICS['classification']['overall_accuracy']*100:.2f}%</b>
                    <div style='width:100%; background:#333; height:12px; border-radius:6px;'>
                        <div style='width:{MODEL_METRICS['classification']['overall_accuracy']*100:.1f}%; background:#5cb85c; height:12px; border-radius:6px;'></div>
                    </div>
                </div>
                """
            )

    # Event handlers
    analyze_btn.click(
        fn=detect_forgery,
        inputs=[input_file],
        outputs=[output_image, metrics_gauge, output_html]
    )

    # Clear resets the upload, the image and the figure, and restores the
    # placeholder text in the report panel.
    clear_btn.click(
        fn=lambda: (None, None, None, "<i>No analysis yet. Upload a document and click Analyze.</i>"),
        inputs=None,
        outputs=[input_file, output_image, metrics_gauge, output_html]
    )


if __name__ == "__main__":
    demo.launch()