dhairyashah committed on
Commit
bffe517
1 Parent(s): 7a60200

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -10
app.py CHANGED
@@ -9,7 +9,8 @@ from facenet_pytorch import MTCNN, InceptionResnetV1
9
  import numpy as np
10
  from pytorch_grad_cam import GradCAM
11
  from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
12
- import os
 
13
 
14
  app = Flask(__name__)
15
 
@@ -27,12 +28,16 @@ DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
27
  mtcnn = MTCNN(select_largest=False, post_process=False, device=DEVICE).to(DEVICE).eval()
28
 
29
  model = InceptionResnetV1(pretrained="vggface2", classify=True, num_classes=1, device=DEVICE)
30
- # Model Credits: https://huggingface.co/spaces/dhairyashah/deepfake-alpha-version/blob/main/CREDITS.md
31
  checkpoint = torch.load("resnetinceptionv1_epoch_32.pth", map_location=torch.device('cpu'))
32
  model.load_state_dict(checkpoint['model_state_dict'])
33
  model.to(DEVICE)
34
  model.eval()
35
 
 
 
 
 
 
36
  def allowed_file(filename):
37
  return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
38
 
@@ -40,7 +45,7 @@ def allowed_file(filename):
40
  def process_frame(frame):
41
  face = mtcnn(frame)
42
  if face is None:
43
- return None, None
44
 
45
  face = face.unsqueeze(0)
46
  face = F.interpolate(face, size=(256, 256), mode='bilinear', align_corners=False)
@@ -53,14 +58,22 @@ def process_frame(frame):
53
  output = torch.sigmoid(model(face).squeeze(0))
54
  prediction = "fake" if output.item() >= 0.5 else "real"
55
 
56
- return prediction, output.item()
 
 
 
 
 
 
 
57
 
58
  @spaces.GPU
59
- def analyze_video(video_path, sample_rate=30):
60
  cap = cv2.VideoCapture(video_path)
61
  frame_count = 0
62
  fake_count = 0
63
  total_processed = 0
 
64
 
65
  while cap.isOpened():
66
  ret, frame = cap.read()
@@ -69,22 +82,32 @@ def analyze_video(video_path, sample_rate=30):
69
 
70
  if frame_count % sample_rate == 0:
71
  rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
72
- prediction, confidence = process_frame(rgb_frame)
73
 
74
  if prediction is not None:
75
  total_processed += 1
76
  if prediction == "fake":
77
  fake_count += 1
78
 
 
 
 
 
 
 
 
79
  frame_count += 1
80
 
81
  cap.release()
82
 
83
  if total_processed > 0:
84
  fake_percentage = (fake_count / total_processed) * 100
85
- return fake_percentage
 
 
 
86
  else:
87
- return 0
88
 
89
  @app.route('/analyze', methods=['POST'])
90
  def analyze_video_api():
@@ -102,12 +125,17 @@ def analyze_video_api():
102
  file.save(filepath)
103
 
104
  try:
105
- fake_percentage = analyze_video(filepath)
106
  os.remove(filepath) # Remove the file after analysis
107
 
 
 
 
 
108
  result = {
109
  'fake_percentage': round(fake_percentage, 2),
110
- 'is_likely_deepfake': fake_percentage >= 60
 
111
  }
112
 
113
  return jsonify(result), 200
 
9
  import numpy as np
10
  from pytorch_grad_cam import GradCAM
11
  from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
12
+ from pytorch_grad_cam.utils.image import show_cam_on_image
13
+ import base64
14
 
15
  app = Flask(__name__)
16
 
 
28
  mtcnn = MTCNN(select_largest=False, post_process=False, device=DEVICE).to(DEVICE).eval()
29
 
30
  model = InceptionResnetV1(pretrained="vggface2", classify=True, num_classes=1, device=DEVICE)
 
31
  checkpoint = torch.load("resnetinceptionv1_epoch_32.pth", map_location=torch.device('cpu'))
32
  model.load_state_dict(checkpoint['model_state_dict'])
33
  model.to(DEVICE)
34
  model.eval()
35
 
36
+ # GradCAM setup
37
+ target_layers = [model.block8.branch1[-1]]
38
+ cam = GradCAM(model=model, target_layers=target_layers)
39
+ targets = [ClassifierOutputTarget(0)]
40
+
41
  def allowed_file(filename):
42
  return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
43
 
 
45
  def process_frame(frame):
46
  face = mtcnn(frame)
47
  if face is None:
48
+ return None, None, None
49
 
50
  face = face.unsqueeze(0)
51
  face = F.interpolate(face, size=(256, 256), mode='bilinear', align_corners=False)
 
58
  output = torch.sigmoid(model(face).squeeze(0))
59
  prediction = "fake" if output.item() >= 0.5 else "real"
60
 
61
+ # Generate GradCAM
62
+ grayscale_cam = cam(input_tensor=face, targets=targets, eigen_smooth=True)
63
+ grayscale_cam = grayscale_cam[0, :]
64
+
65
+ face_image_to_plot = face.squeeze(0).permute(1, 2, 0).cpu().detach().numpy()
66
+ visualization = show_cam_on_image(face_image_to_plot, grayscale_cam, use_rgb=True)
67
+
68
+ return prediction, output.item(), visualization
69
 
70
  @spaces.GPU
71
+ def analyze_video(video_path, sample_rate=30, top_n=5):
72
  cap = cv2.VideoCapture(video_path)
73
  frame_count = 0
74
  fake_count = 0
75
  total_processed = 0
76
+ frames_info = []
77
 
78
  while cap.isOpened():
79
  ret, frame = cap.read()
 
82
 
83
  if frame_count % sample_rate == 0:
84
  rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
85
+ prediction, confidence, visualization = process_frame(rgb_frame)
86
 
87
  if prediction is not None:
88
  total_processed += 1
89
  if prediction == "fake":
90
  fake_count += 1
91
 
92
+ frames_info.append({
93
+ 'frame_number': frame_count,
94
+ 'prediction': prediction,
95
+ 'confidence': confidence,
96
+ 'visualization': visualization
97
+ })
98
+
99
  frame_count += 1
100
 
101
  cap.release()
102
 
103
  if total_processed > 0:
104
  fake_percentage = (fake_count / total_processed) * 100
105
+ frames_info.sort(key=lambda x: x['confidence'], reverse=True)
106
+ top_frames = frames_info[:top_n]
107
+
108
+ return fake_percentage, top_frames
109
  else:
110
+ return 0, []
111
 
112
  @app.route('/analyze', methods=['POST'])
113
  def analyze_video_api():
 
125
  file.save(filepath)
126
 
127
  try:
128
+ fake_percentage, top_frames = analyze_video(filepath)
129
  os.remove(filepath) # Remove the file after analysis
130
 
131
+ # Convert numpy arrays to base64 encoded strings
132
+ for frame in top_frames:
133
+ frame['visualization'] = base64.b64encode(cv2.imencode('.png', frame['visualization'])[1]).decode('utf-8')
134
+
135
  result = {
136
  'fake_percentage': round(fake_percentage, 2),
137
+ 'is_likely_deepfake': fake_percentage >= 60,
138
+ 'top_frames': top_frames
139
  }
140
 
141
  return jsonify(result), 200