Shiwanni commited on
Commit
47f4e7c
·
verified ·
1 Parent(s): f113b15

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +158 -172
app.py CHANGED
@@ -1,87 +1,132 @@
1
  import gradio as gr
2
  import torch
 
 
3
  from transformers import AutoImageProcessor, AutoModelForImageClassification
4
  from PIL import Image
5
- import matplotlib.pyplot as plt
6
- import numpy as np
7
  import time
8
 
9
- # Load model and processor
10
- model_name = "dima806/deepfake_vs_real_image_detection"
11
- processor = AutoImageProcessor.from_pretrained(model_name)
12
- model = AutoModelForImageClassification.from_pretrained(model_name)
 
 
13
 
14
- def predict_image(image):
 
 
 
 
 
 
 
 
 
 
 
15
  if image is None:
16
- return {"result": "⚠️ Please upload an image", "class": "neutral"}
17
 
18
  try:
19
- # Show loading animation
20
- time.sleep(1) # Simulate processing time
21
-
22
- # Convert image to RGB if needed
23
  if image.mode != "RGB":
24
  image = image.convert("RGB")
25
 
26
- # Preprocess and predict
 
 
 
 
 
 
 
27
  inputs = processor(images=image, return_tensors="pt")
 
 
28
  with torch.no_grad():
29
  outputs = model(**inputs)
30
 
31
  # Get probabilities
32
  probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
33
- real_prob = probs[model.config.label2id["real"]].item()
34
- fake_prob = probs[model.config.label2id["fake"]].item()
35
 
36
- # Create visualization
37
- fig = create_visualization(real_prob, fake_prob)
 
38
 
39
- if real_prob > fake_prob:
40
- return {
41
- "result": f"🎉 Authentic Image (Confidence: {real_prob*100:.1f}%)",
42
- "class": "authentic",
43
- "details": f"Real Probability: {real_prob*100:.1f}% | Fake Probability: {fake_prob*100:.1f}%",
44
- "visualization": fig
45
- }
46
- else:
47
- return {
48
- "result": f"⚠️ Potential Deepfake (Confidence: {fake_prob*100:.1f}%)",
49
- "class": "fake",
50
- "details": f"Fake Probability: {fake_prob*100:.1f}% | Real Probability: {real_prob*100:.1f}%",
51
- "visualization": fig
52
- }
53
  except Exception as e:
54
- return {"result": f"Error: {str(e)}", "class": "error"}
55
 
56
- def create_visualization(real_prob, fake_prob):
57
- labels = ['Real', 'Fake']
58
- probabilities = [real_prob, fake_prob]
59
- colors = ['#4CAF50', '#F44336']
60
 
61
- fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
 
 
 
62
 
63
- # Bar chart
64
- bars = ax1.bar(labels, probabilities, color=colors)
65
- ax1.set_ylim(0, 1)
66
- ax1.set_title('Probability Distribution')
67
- ax1.set_ylabel('Probability')
68
 
69
- # Add value labels on top of bars
70
- for bar in bars:
71
- height = bar.get_height()
72
- ax1.text(bar.get_x() + bar.get_width()/2., height,
73
- f'{height:.2f}',
74
- ha='center', va='bottom')
 
75
 
76
- # Pie chart
77
- ax2.pie(probabilities, labels=labels, colors=colors, autopct='%1.1f%%', startangle=90)
78
- ax2.axis('equal')
79
- ax2.set_title('Probability Ratio')
 
 
 
 
 
 
 
 
80
 
81
  plt.tight_layout()
82
  return fig
83
 
84
- # Custom CSS for attractive UI
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  custom_css = """
86
  :root {
87
  --primary: #4b6cb7;
@@ -89,11 +134,10 @@ custom_css = """
89
  --authentic: #4CAF50;
90
  --fake: #F44336;
91
  --neutral: #2196F3;
92
- --error: #FF9800;
93
  }
94
 
95
  #main-container {
96
- max-width: 900px;
97
  margin: auto;
98
  padding: 25px;
99
  border-radius: 15px;
@@ -126,35 +170,18 @@ custom_css = """
126
  padding: 20px;
127
  border-radius: 12px;
128
  margin-top: 20px;
129
- font-size: 1.2em;
130
- font-weight: bold;
131
- text-align: center;
132
  transition: all 0.3s ease;
133
  box-shadow: 0 4px 6px rgba(0,0,0,0.1);
 
134
  }
135
 
136
- .authentic {
137
- background: linear-gradient(135deg, #e6f7e6 0%, #c8e6c9 100%);
138
- color: var(--authentic);
139
- border-left: 5px solid var(--authentic);
140
- }
141
-
142
- .fake {
143
- background: linear-gradient(135deg, #ffebee 0%, #ffcdd2 100%);
144
- color: var(--fake);
145
- border-left: 5px solid var(--fake);
146
- }
147
-
148
- .neutral {
149
- background: linear-gradient(135deg, #e3f2fd 0%, #bbdefb 100%);
150
- color: var(--neutral);
151
- border-left: 5px solid var(--neutral);
152
- }
153
-
154
- .error {
155
- background: linear-gradient(135deg, #fff3e0 0%, #ffe0b2 100%);
156
- color: var(--error);
157
- border-left: 5px solid var(--error);
158
  }
159
 
160
  .btn-primary {
@@ -166,17 +193,11 @@ custom_css = """
166
  font-weight: bold !important;
167
  }
168
 
169
- .btn-primary:hover {
170
- transform: translateY(-2px);
171
- box-shadow: 0 4px 8px rgba(0,0,0,0.2) !important;
172
- }
173
-
174
- .btn-secondary {
175
  background: white !important;
176
- color: var(--primary) !important;
177
  border: 2px solid var(--primary) !important;
178
- padding: 12px 24px !important;
179
  border-radius: 8px !important;
 
180
  }
181
 
182
  .footer {
@@ -186,24 +207,16 @@ custom_css = """
186
  color: #666;
187
  }
188
 
189
- .animation {
190
- animation: fadeIn 0.5s ease-in-out;
191
- }
192
-
193
- .visualization-box {
194
- border-radius: 12px;
195
- padding: 15px;
196
- background: white;
197
- margin-top: 15px;
198
- box-shadow: 0 4px 6px rgba(0,0,0,0.1);
199
- }
200
-
201
  @keyframes fadeIn {
202
  from { opacity: 0; transform: translateY(10px); }
203
  to { opacity: 1; transform: translateY(0); }
204
  }
205
 
206
- .processing-animation {
 
 
 
 
207
  animation: pulse 1.5s infinite;
208
  }
209
 
@@ -218,93 +231,66 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
218
  with gr.Column(elem_id="main-container"):
219
  with gr.Column(elem_classes=["header"]):
220
  gr.Markdown("# 🛡️ DeepGuard AI")
221
- gr.Markdown("Detect AI-generated images with advanced deep learning")
222
 
223
  with gr.Row():
224
- with gr.Column(scale=2):
225
  image_input = gr.Image(
226
  type="pil",
227
- label="Drag & drop your image here",
228
  elem_classes=["upload-area", "animation"]
229
  )
 
230
  with gr.Row():
231
- submit_btn = gr.Button("Analyze Now", elem_classes=["btn-primary", "animation"])
232
- clear_btn = gr.Button("Clear", elem_classes=["btn-secondary", "animation"])
 
 
 
 
 
 
 
 
233
 
234
  with gr.Column(scale=1):
235
- result_display = gr.Textbox(
236
- label="Analysis Result",
237
- interactive=False,
238
- elem_classes=["result-box", "neutral", "animation"]
239
- )
240
- details = gr.Textbox(
241
- label="Detailed Analysis",
242
- interactive=False,
243
- visible=False
244
- )
245
 
246
- visualization = gr.Plot(
247
- label="Probability Visualization",
248
- visible=False,
249
- elem_classes=["visualization-box"]
250
- )
 
 
 
 
 
 
 
 
251
 
252
  gr.Markdown("""
253
  <div class="footer">
254
- *Note: This AI detection tool provides estimates based on image analysis.
255
- Results should be verified with additional methods for critical decisions.*
256
  </div>
257
  """)
258
-
259
- def update_ui(image, prediction):
260
- if image is None:
261
- return {
262
- result_display: gr.Textbox.update(
263
- value="📁 Upload an image to begin analysis",
264
- elem_classes=["result-box", "neutral"]
265
- ),
266
- details: gr.Textbox.update(visible=False),
267
- visualization: gr.Plot.update(visible=False)
268
- }
269
-
270
- show_visualization = "visualization" in prediction
271
- show_details = "details" in prediction
272
-
273
- updates = {
274
- result_display: gr.Textbox.update(
275
- value=prediction["result"],
276
- elem_classes=["result-box", prediction["class"]]
277
- ),
278
- details: gr.Textbox.update(
279
- value=prediction.get("details", ""),
280
- visible=show_details
281
- ),
282
- visualization: gr.Plot.update(
283
- value=prediction.get("visualization"),
284
- visible=show_visualization
285
- )
286
- }
287
-
288
- return updates
289
-
290
- submit_btn.click(
291
- fn=predict_image,
292
- inputs=image_input,
293
- outputs=gr.JSON(visible=False)
294
- ).then(
295
- fn=update_ui,
296
- inputs=[image_input, gr.JSON()],
297
- outputs=[result_display, details, visualization]
298
- )
299
 
300
- clear_btn.click(
301
- fn=lambda: [None, {"result": "📁 Upload an image to begin analysis", "class": "neutral"}],
302
- inputs=None,
303
- outputs=[image_input, gr.JSON()]
304
- ).then(
305
- fn=update_ui,
306
- inputs=[image_input, gr.JSON()],
307
- outputs=[result_display, details, visualization]
308
  )
309
 
310
  if __name__ == "__main__":
 
1
  import gradio as gr
2
  import torch
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
  from transformers import AutoImageProcessor, AutoModelForImageClassification
6
  from PIL import Image
7
+ import cv2
8
+ from skimage import exposure
9
  import time
10
 
11
+ # Load models (using free Hugging Face models)
12
+ MODEL_NAMES = {
13
+ "Model 1": "dima806/deepfake_vs_real_image_detection",
14
+ "Model 2": "saltacc/anime-ai-detect",
15
+ "Model 3": "rizvandwiki/gansfake-detector"
16
+ }
17
 
18
+ # Initialize models
19
+ models = {}
20
+ processors = {}
21
+
22
+ for name, path in MODEL_NAMES.items():
23
+ try:
24
+ processors[name] = AutoImageProcessor.from_pretrained(path)
25
+ models[name] = AutoModelForImageClassification.from_pretrained(path)
26
+ except:
27
+ print(f"Could not load model: {name}")
28
+
29
+ def analyze_image(image, selected_model):
30
  if image is None:
31
+ return None, None, "Please upload an image first", None
32
 
33
  try:
34
+ # Convert to RGB if needed
 
 
 
35
  if image.mode != "RGB":
36
  image = image.convert("RGB")
37
 
38
+ # Get model and processor
39
+ model = models.get(selected_model)
40
+ processor = processors.get(selected_model)
41
+
42
+ if not model or not processor:
43
+ return None, None, "Selected model not available", None
44
+
45
+ # Preprocess image
46
  inputs = processor(images=image, return_tensors="pt")
47
+
48
+ # Predict
49
  with torch.no_grad():
50
  outputs = model(**inputs)
51
 
52
  # Get probabilities
53
  probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
 
 
54
 
55
+ # Create visualizations
56
+ heatmap = generate_heatmap(image, model, processor)
57
+ chart_fig = create_probability_chart(probs, model.config.id2label)
58
 
59
+ # Format results
60
+ result_text = format_results(probs, model.config.id2label)
61
+
62
+ return heatmap, chart_fig, result_text, create_model_info(selected_model)
63
+
 
 
 
 
 
 
 
 
 
64
  except Exception as e:
65
+ return None, None, f"Error: {str(e)}", None
66
 
67
+ def generate_heatmap(image, model, processor):
68
+ """Generate a heatmap showing important regions for the prediction"""
69
+ # Convert to numpy array
70
+ img_array = np.array(image)
71
 
72
+ # Create a saliency map (simple version)
73
+ gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
74
+ blurred = cv2.GaussianBlur(gray, (21, 21), 0)
75
+ heatmap = cv2.applyColorMap(blurred, cv2.COLORMAP_JET)
76
 
77
+ # Blend with original image
78
+ heatmap = cv2.addWeighted(img_array, 0.7, heatmap, 0.3, 0)
 
 
 
79
 
80
+ # Convert back to PIL
81
+ return Image.fromarray(heatmap)
82
+
83
+ def create_probability_chart(probs, id2label):
84
+ """Create a bar chart of class probabilities"""
85
+ labels = [id2label[i] for i in range(len(probs))]
86
+ colors = ['#4CAF50' if 'real' in label.lower() else '#F44336' for label in labels]
87
 
88
+ fig, ax = plt.subplots(figsize=(8, 4))
89
+ bars = ax.barh(labels, probs.numpy(), color=colors)
90
+ ax.set_xlim(0, 1)
91
+ ax.set_title('Detection Probabilities', pad=20)
92
+ ax.set_xlabel('Probability')
93
+
94
+ # Add value labels
95
+ for bar in bars:
96
+ width = bar.get_width()
97
+ ax.text(width + 0.02, bar.get_y() + bar.get_height()/2,
98
+ f'{width:.2f}',
99
+ va='center')
100
 
101
  plt.tight_layout()
102
  return fig
103
 
104
+ def format_results(probs, id2label):
105
+ """Format the results as text"""
106
+ results = []
107
+ for i, prob in enumerate(probs):
108
+ results.append(f"{id2label[i]}: {prob*100:.1f}%")
109
+
110
+ max_prob = max(probs)
111
+ max_class = id2label[torch.argmax(probs).item()]
112
+
113
+ if 'real' in max_class.lower():
114
+ conclusion = f"Conclusion: This image appears to be AUTHENTIC with {max_prob*100:.1f}% confidence"
115
+ else:
116
+ conclusion = f"Conclusion: This image appears to be FAKE/GENERATED with {max_prob*100:.1f}% confidence"
117
+
118
+ return "\n".join([conclusion, "", "Detailed probabilities:"] + results)
119
+
120
+ def create_model_info(model_name):
121
+ """Create information about the current model"""
122
+ info = {
123
+ "Model 1": "Trained to detect deepfakes vs real human faces",
124
+ "Model 2": "Specialized in detecting AI-generated anime images",
125
+ "Model 3": "General GAN-generated image detector"
126
+ }
127
+ return info.get(model_name, "No information available about this model")
128
+
129
+ # Custom CSS for the interface
130
  custom_css = """
131
  :root {
132
  --primary: #4b6cb7;
 
134
  --authentic: #4CAF50;
135
  --fake: #F44336;
136
  --neutral: #2196F3;
 
137
  }
138
 
139
  #main-container {
140
+ max-width: 1200px;
141
  margin: auto;
142
  padding: 25px;
143
  border-radius: 15px;
 
170
  padding: 20px;
171
  border-radius: 12px;
172
  margin-top: 20px;
173
+ font-size: 1.1em;
 
 
174
  transition: all 0.3s ease;
175
  box-shadow: 0 4px 6px rgba(0,0,0,0.1);
176
+ background: white;
177
  }
178
 
179
+ .visualization-box {
180
+ border-radius: 12px;
181
+ padding: 15px;
182
+ background: white;
183
+ margin-top: 15px;
184
+ box-shadow: 0 4px 6px rgba(0,0,0,0.1);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  }
186
 
187
  .btn-primary {
 
193
  font-weight: bold !important;
194
  }
195
 
196
+ .model-select {
 
 
 
 
 
197
  background: white !important;
 
198
  border: 2px solid var(--primary) !important;
 
199
  border-radius: 8px !important;
200
+ padding: 8px 12px !important;
201
  }
202
 
203
  .footer {
 
207
  color: #666;
208
  }
209
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  @keyframes fadeIn {
211
  from { opacity: 0; transform: translateY(10px); }
212
  to { opacity: 1; transform: translateY(0); }
213
  }
214
 
215
+ .animation {
216
+ animation: fadeIn 0.5s ease-in-out;
217
+ }
218
+
219
+ .loading {
220
  animation: pulse 1.5s infinite;
221
  }
222
 
 
231
  with gr.Column(elem_id="main-container"):
232
  with gr.Column(elem_classes=["header"]):
233
  gr.Markdown("# 🛡️ DeepGuard AI")
234
+ gr.Markdown("## Advanced Deepfake Detection System")
235
 
236
  with gr.Row():
237
+ with gr.Column(scale=1.5):
238
  image_input = gr.Image(
239
  type="pil",
240
+ label="Upload Image for Analysis",
241
  elem_classes=["upload-area", "animation"]
242
  )
243
+
244
  with gr.Row():
245
+ model_selector = gr.Dropdown(
246
+ choices=list(MODEL_NAMES.keys()),
247
+ value=list(MODEL_NAMES.keys())[0],
248
+ label="Select Detection Model",
249
+ elem_classes=["model-select", "animation"]
250
+ )
251
+ analyze_btn = gr.Button(
252
+ "Analyze Image",
253
+ elem_classes=["btn-primary", "animation"]
254
+ )
255
 
256
  with gr.Column(scale=1):
257
+ with gr.Column(elem_classes=["visualization-box"]):
258
+ heatmap_output = gr.Image(
259
+ label="Attention Heatmap",
260
+ interactive=False
261
+ )
262
+
263
+ with gr.Column(elem_classes=["visualization-box"]):
264
+ chart_output = gr.Plot(
265
+ label="Detection Probabilities"
266
+ )
267
 
268
+ with gr.Column(elem_classes=["result-box", "animation"]):
269
+ result_output = gr.Textbox(
270
+ label="Analysis Results",
271
+ interactive=False,
272
+ lines=8
273
+ )
274
+
275
+ with gr.Column(elem_classes=["result-box", "animation"]):
276
+ model_info = gr.Textbox(
277
+ label="Model Information",
278
+ interactive=False,
279
+ lines=3
280
+ )
281
 
282
  gr.Markdown("""
283
  <div class="footer">
284
+ *Note: This tool provides probabilistic estimates. Always verify important findings with additional methods.<br>
285
+ Models may produce false positives/negatives. Performance varies by image type and quality.*
286
  </div>
287
  """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288
 
289
+ # Event handlers
290
+ analyze_btn.click(
291
+ fn=analyze_image,
292
+ inputs=[image_input, model_selector],
293
+ outputs=[heatmap_output, chart_output, result_output, model_info]
 
 
 
294
  )
295
 
296
  if __name__ == "__main__":