JKrishnanandhaa committed on
Commit
5b33d5d
·
verified ·
1 Parent(s): b606a93

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +218 -148
app.py CHANGED
@@ -1,6 +1,7 @@
1
  """
2
- Document Forgery Detection – Professional Gradio Dashboard
3
- Hugging Face Spaces Deployment
 
4
  """
5
 
6
  import gradio as gr
@@ -8,14 +9,11 @@ import torch
8
  import cv2
9
  import numpy as np
10
  from PIL import Image
11
- import plotly.graph_objects as go
12
  from pathlib import Path
13
  import sys
14
- import json
15
 
16
- # -------------------------------------------------
17
- # PATH SETUP
18
- # -------------------------------------------------
19
  sys.path.insert(0, str(Path(__file__).parent))
20
 
21
  from src.models import get_model
@@ -26,181 +24,253 @@ from src.features.region_extraction import get_mask_refiner, get_region_extracto
26
  from src.features.feature_extraction import get_feature_extractor
27
  from src.training.classifier import ForgeryClassifier
28
 
29
- # -------------------------------------------------
30
- # CONSTANTS
31
- # -------------------------------------------------
32
- CLASS_NAMES = {0: "Copy-Move", 1: "Splicing", 2: "Generation"}
33
  CLASS_COLORS = {
34
- 0: (255, 0, 0),
35
- 1: (0, 255, 0),
36
- 2: (0, 0, 255),
37
  }
38
 
39
- # -------------------------------------------------
40
- # FORGERY DETECTOR (UNCHANGED CORE LOGIC)
41
- # -------------------------------------------------
42
  class ForgeryDetector:
 
 
43
  def __init__(self):
44
  print("Loading models...")
45
-
46
- self.config = get_config("config.yaml")
47
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
48
-
 
 
49
  self.model = get_model(self.config).to(self.device)
50
- checkpoint = torch.load("models/best_doctamper.pth", map_location=self.device)
51
- self.model.load_state_dict(checkpoint["model_state_dict"])
52
  self.model.eval()
53
-
 
54
  self.classifier = ForgeryClassifier(self.config)
55
- self.classifier.load("models/classifier")
56
-
57
- self.preprocessor = DocumentPreprocessor(self.config, "doctamper")
58
- self.augmentation = DatasetAwareAugmentation(self.config, "doctamper", is_training=False)
 
59
  self.mask_refiner = get_mask_refiner(self.config)
60
  self.region_extractor = get_region_extractor(self.config)
61
  self.feature_extractor = get_feature_extractor(self.config, is_text_document=True)
62
-
63
- print("βœ“ Models loaded")
64
-
65
  def detect(self, image):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  if isinstance(image, Image.Image):
67
  image = np.array(image)
68
-
69
- if image.ndim == 2:
 
70
  image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
71
  elif image.shape[2] == 4:
72
  image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
73
-
74
- original = image.copy()
75
-
 
76
  preprocessed, _ = self.preprocessor(image, None)
 
 
77
  augmented = self.augmentation(preprocessed, None)
78
- image_tensor = augmented["image"].unsqueeze(0).to(self.device)
79
-
 
80
  with torch.no_grad():
81
  logits, decoder_features = self.model(image_tensor)
82
  prob_map = torch.sigmoid(logits).cpu().numpy()[0, 0]
83
-
84
- binary = (prob_map > 0.5).astype(np.uint8)
85
- refined = self.mask_refiner.refine(binary, original_size=original.shape[:2])
86
- regions = self.region_extractor.extract(refined, prob_map, original)
87
-
 
 
 
 
88
  results = []
89
- for r in regions:
 
90
  features = self.feature_extractor.extract(
91
- preprocessed, r["region_mask"], [f.cpu() for f in decoder_features]
 
 
92
  )
93
-
 
94
  if features.ndim == 1:
95
  features = features.reshape(1, -1)
96
-
97
- if features.shape[1] != 526:
98
- pad = max(0, 526 - features.shape[1])
99
- features = np.pad(features, ((0, 0), (0, pad)))[:, :526]
100
-
101
- pred, conf = self.classifier.predict(features)
102
- if conf[0] > 0.6:
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  results.append({
104
- "bounding_box": r["bounding_box"],
105
- "forgery_type": CLASS_NAMES[int(pred[0])],
106
- "confidence": float(conf[0]),
 
107
  })
108
-
109
- overlay = self._draw_overlay(original, results)
110
-
111
- return overlay, {
112
- "num_detections": len(results),
113
- "detections": results,
 
 
 
 
 
 
114
  }
115
-
116
- def _draw_overlay(self, image, results):
117
- out = image.copy()
118
- for r in results:
119
- x, y, w, h = r["bounding_box"]
120
- fid = [k for k, v in CLASS_NAMES.items() if v == r["forgery_type"]][0]
121
- color = CLASS_COLORS[fid]
122
-
123
- cv2.rectangle(out, (x, y), (x + w, y + h), color, 2)
124
- label = f"{r['forgery_type']} ({r['confidence']*100:.1f}%)"
125
- cv2.putText(out, label, (x, y - 6),
126
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
127
- return out
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
 
129
 
130
- detector = ForgeryDetector()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
- # -------------------------------------------------
133
- # METRIC VISUALS
134
- # -------------------------------------------------
135
- def gauge(value, title):
136
- fig = go.Figure(go.Indicator(
137
- mode="gauge+number",
138
- value=value,
139
- title={"text": title},
140
- gauge={"axis": {"range": [0, 100]}, "bar": {"color": "#2563eb"}}
141
- ))
142
- fig.update_layout(height=240, margin=dict(t=40, b=20))
143
- return fig
144
-
145
- # -------------------------------------------------
146
- # GRADIO CALLBACK
147
- # -------------------------------------------------
148
- def run_detection(file):
149
- image = Image.open(file.name)
150
- overlay, result = detector.detect(image)
151
-
152
- avg_conf = (
153
- sum(d["confidence"] for d in result["detections"]) / max(1, result["num_detections"])
154
- ) * 100
155
-
156
- return (
157
- overlay,
158
- result,
159
- gauge(75, "Localization Dice (%)"),
160
- gauge(92, "Classifier Accuracy (%)"),
161
- gauge(avg_conf, "Avg Detection Confidence (%)"),
162
- )
163
-
164
- # -------------------------------------------------
165
- # UI
166
- # -------------------------------------------------
167
- with gr.Blocks(theme=gr.themes.Soft(), title="Document Forgery Detection") as demo:
168
-
169
- gr.Markdown("# πŸ“„ Document Forgery Detection System")
170
-
171
- with gr.Row():
172
- file_input = gr.File(label="Upload Document (Image/PDF)")
173
- detect_btn = gr.Button("Run Detection", variant="primary")
174
-
175
- output_img = gr.Image(label="Forgery Localization Result", type="numpy")
176
-
177
- with gr.Tabs():
178
- with gr.Tab("πŸ“Š Metrics"):
179
- with gr.Row():
180
- dice_plot = gr.Plot()
181
- acc_plot = gr.Plot()
182
- conf_plot = gr.Plot()
183
-
184
- with gr.Tab("🧾 Details"):
185
- json_out = gr.JSON()
186
-
187
- with gr.Tab("πŸ‘₯ Team"):
188
- gr.Markdown("""
189
- **Document Forgery Detection Project**
190
-
191
- - Krishnanandhaa β€” Model & Training
192
- - Teammate 1 β€” Feature Engineering
193
- - Teammate 2 β€” Evaluation
194
- - Teammate 3 β€” Deployment
195
-
196
- *Collaborators are added via Hugging Face Space settings.*
197
- """)
198
-
199
- detect_btn.click(
200
- run_detection,
201
- inputs=file_input,
202
- outputs=[output_img, json_out, dice_plot, acc_plot, conf_plot]
203
- )
204
 
205
  if __name__ == "__main__":
206
  demo.launch()
 
 
1
  """
2
+ Document Forgery Detection - Gradio Interface for Hugging Face Spaces
3
+
4
+ This app provides a web interface for detecting and classifying document forgeries.
5
  """
6
 
7
  import gradio as gr
 
9
  import cv2
10
  import numpy as np
11
  from PIL import Image
12
+ import json
13
  from pathlib import Path
14
  import sys
 
15
 
16
+ # Add src to path
 
 
17
  sys.path.insert(0, str(Path(__file__).parent))
18
 
19
  from src.models import get_model
 
24
  from src.features.feature_extraction import get_feature_extractor
25
  from src.training.classifier import ForgeryClassifier
26
 
27
+ # Class names
28
+ CLASS_NAMES = {0: 'Copy-Move', 1: 'Splicing', 2: 'Generation'}
 
 
29
  CLASS_COLORS = {
30
+ 0: (255, 0, 0), # Red for Copy-Move
31
+ 1: (0, 255, 0), # Green for Splicing
32
+ 2: (0, 0, 255) # Blue for Generation
33
  }
34
 
35
+
 
 
36
  class ForgeryDetector:
37
+ """Main forgery detection pipeline"""
38
+
39
  def __init__(self):
40
  print("Loading models...")
41
+
42
+ # Load config
43
+ self.config = get_config('config.yaml')
44
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
45
+
46
+ # Load segmentation model
47
  self.model = get_model(self.config).to(self.device)
48
+ checkpoint = torch.load('models/best_doctamper.pth', map_location=self.device)
49
+ self.model.load_state_dict(checkpoint['model_state_dict'])
50
  self.model.eval()
51
+
52
+ # Load classifier
53
  self.classifier = ForgeryClassifier(self.config)
54
+ self.classifier.load('models/classifier')
55
+
56
+ # Initialize components
57
+ self.preprocessor = DocumentPreprocessor(self.config, 'doctamper')
58
+ self.augmentation = DatasetAwareAugmentation(self.config, 'doctamper', is_training=False)
59
  self.mask_refiner = get_mask_refiner(self.config)
60
  self.region_extractor = get_region_extractor(self.config)
61
  self.feature_extractor = get_feature_extractor(self.config, is_text_document=True)
62
+
63
+ print("βœ“ Models loaded successfully!")
64
+
65
  def detect(self, image):
66
+ """
67
+ Detect forgeries in document image or PDF
68
+
69
+ Args:
70
+ image: PIL Image, numpy array, or path to PDF file
71
+
72
+ Returns:
73
+ overlay_image: Image with detection overlay
74
+ results_json: Detection results as JSON
75
+ """
76
+ # Handle PDF files
77
+ if isinstance(image, str) and image.lower().endswith('.pdf'):
78
+ import fitz # PyMuPDF
79
+ # Open PDF and convert first page to image
80
+ pdf_document = fitz.open(image)
81
+ page = pdf_document[0] # First page
82
+ pix = page.get_pixmap(matrix=fitz.Matrix(2, 2)) # 2x scale for better quality
83
+ image = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, pix.n)
84
+ if pix.n == 4: # RGBA
85
+ image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
86
+ pdf_document.close()
87
+
88
+ # Convert PIL to numpy
89
  if isinstance(image, Image.Image):
90
  image = np.array(image)
91
+
92
+ # Convert to RGB
93
+ if len(image.shape) == 2:
94
  image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
95
  elif image.shape[2] == 4:
96
  image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
97
+
98
+ original_image = image.copy()
99
+
100
+ # Preprocess
101
  preprocessed, _ = self.preprocessor(image, None)
102
+
103
+ # Augment
104
  augmented = self.augmentation(preprocessed, None)
105
+ image_tensor = augmented['image'].unsqueeze(0).to(self.device)
106
+
107
+ # Run localization
108
  with torch.no_grad():
109
  logits, decoder_features = self.model(image_tensor)
110
  prob_map = torch.sigmoid(logits).cpu().numpy()[0, 0]
111
+
112
+ # Refine mask
113
+ binary_mask = (prob_map > 0.5).astype(np.uint8)
114
+ refined_mask = self.mask_refiner.refine(binary_mask, original_size=original_image.shape[:2])
115
+
116
+ # Extract regions
117
+ regions = self.region_extractor.extract(refined_mask, prob_map, original_image)
118
+
119
+ # Classify regions
120
  results = []
121
+ for region in regions:
122
+ # Extract features
123
  features = self.feature_extractor.extract(
124
+ preprocessed,
125
+ region['region_mask'],
126
+ [f.cpu() for f in decoder_features]
127
  )
128
+
129
+ # Reshape features to 2D array (1, n_features) for classifier
130
  if features.ndim == 1:
131
  features = features.reshape(1, -1)
132
+
133
+ # TEMPORARY FIX: Pad features to match classifier's expected count
134
+ expected_features = 526
135
+ current_features = features.shape[1]
136
+ if current_features < expected_features:
137
+ # Pad with zeros
138
+ padding = np.zeros((features.shape[0], expected_features - current_features))
139
+ features = np.hstack([features, padding])
140
+ print(f"Warning: Padded features from {current_features} to {expected_features}")
141
+ elif current_features > expected_features:
142
+ # Truncate
143
+ features = features[:, :expected_features]
144
+ print(f"Warning: Truncated features from {current_features} to {expected_features}")
145
+
146
+ # Classify
147
+ predictions, confidences = self.classifier.predict(features)
148
+ forgery_type = int(predictions[0])
149
+ confidence = float(confidences[0])
150
+
151
+ if confidence > 0.6: # Confidence threshold
152
  results.append({
153
+ 'region_id': region['region_id'],
154
+ 'bounding_box': region['bounding_box'],
155
+ 'forgery_type': CLASS_NAMES[forgery_type],
156
+ 'confidence': confidence
157
  })
158
+
159
+ # Create visualization
160
+ overlay = self._create_overlay(original_image, results)
161
+
162
+ # Create JSON response
163
+ json_results = {
164
+ 'num_detections': len(results),
165
+ 'detections': results,
166
+ 'model_info': {
167
+ 'segmentation_dice': '75%',
168
+ 'classifier_accuracy': '92%'
169
+ }
170
  }
171
+
172
+ return overlay, json_results
173
+
174
+ def _create_overlay(self, image, results):
175
+ """Create overlay visualization"""
176
+ overlay = image.copy()
177
+
178
+ # Draw bounding boxes and labels
179
+ for result in results:
180
+ bbox = result['bounding_box']
181
+ x, y, w, h = bbox
182
+
183
+ forgery_type = result['forgery_type']
184
+ confidence = result['confidence']
185
+
186
+ # Get color
187
+ forgery_id = [k for k, v in CLASS_NAMES.items() if v == forgery_type][0]
188
+ color = CLASS_COLORS[forgery_id]
189
+
190
+ # Draw rectangle
191
+ cv2.rectangle(overlay, (x, y), (x+w, y+h), color, 2)
192
+
193
+ # Draw label
194
+ label = f"{forgery_type}: {confidence:.1%}"
195
+ label_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)
196
+ cv2.rectangle(overlay, (x, y-label_size[1]-10), (x+label_size[0], y), color, -1)
197
+ cv2.putText(overlay, label, (x, y-5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
198
+
199
+ # Add legend
200
+ if len(results) > 0:
201
+ legend_y = 30
202
+ cv2.putText(overlay, f"Detected {len(results)} forgery region(s)",
203
+ (10, legend_y), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2)
204
+
205
+ return overlay
206
+
207
+
208
+ # Initialize detector
209
+ detector = ForgeryDetector()
210
 
211
 
212
+ def detect_forgery(file):
213
+ """Gradio interface function"""
214
+ try:
215
+ if file is None:
216
+ return None, {"error": "No file uploaded"}
217
+
218
+ # Get file path
219
+ file_path = file.name if hasattr(file, 'name') else file
220
+
221
+ # Check if PDF
222
+ if file_path.lower().endswith('.pdf'):
223
+ # Pass PDF path directly to detector
224
+ overlay, results = detector.detect(file_path)
225
+ else:
226
+ # Load image and pass to detector
227
+ image = Image.open(file_path)
228
+ overlay, results = detector.detect(image)
229
+
230
+ return overlay, results # Return dict directly, not json.dumps
231
+ except Exception as e:
232
+ import traceback
233
+ error_details = traceback.format_exc()
234
+ print(f"Error: {error_details}")
235
+ return None, {"error": str(e), "details": error_details}
236
+
237
+
238
+ # Create Gradio interface
239
+ demo = gr.Interface(
240
+ fn=detect_forgery,
241
+ inputs=gr.File(label="Upload Document (Image or PDF)", file_types=["image", ".pdf"]),
242
+ outputs=[
243
+ gr.Image(type="numpy", label="Detection Result"),
244
+ gr.JSON(label="Detection Details")
245
+ ],
246
+ title="πŸ“„ Document Forgery Detector",
247
+ description="""
248
+ Upload a document image or PDF to detect and classify forgeries.
249
+
250
+ **Supported Formats:**
251
+ - πŸ“· Images: JPG, PNG, BMP, TIFF, WebP
252
+ - πŸ“„ PDF: First page will be analyzed
253
+
254
+ **Supported Forgery Types:**
255
+ - πŸ”΄ Copy-Move: Duplicated regions within the document
256
+ - 🟒 Splicing: Content from different sources
257
+ - πŸ”΅ Generation: AI-generated or synthesized content
258
+
259
+ **Model Performance:**
260
+ - Localization: 75% Dice Score
261
+ - Classification: 92% Accuracy
262
+ """,
263
+ article="""
264
+ ### About
265
+ This model uses a hybrid deep learning approach:
266
+ 1. **Localization**: MobileNetV3-Small + UNet-Lite (detects WHERE)
267
+ 2. **Classification**: LightGBM with hybrid features (detects WHAT)
268
+
269
+ Trained on DocTamper dataset (140K samples).
270
+ """
271
+ )
272
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
 
274
  if __name__ == "__main__":
275
  demo.launch()
276
+