varma123 committed on
Commit
f56b8cf
·
verified ·
1 Parent(s): a96c4c6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -35
app.py CHANGED
@@ -20,7 +20,7 @@ mtcnn = MTCNN(
20
  select_largest=False,
21
  post_process=False,
22
  device=DEVICE
23
- ).to(DEVICE).eval()
24
 
25
  model = InceptionResnetV1(
26
  pretrained="vggface2",
@@ -35,8 +35,38 @@ model.to(DEVICE)
35
  model.eval()
36
 
37
  # Model Inference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  def predict_video(input_video):
39
- """Predict whether the input video contains real or fake faces"""
40
  cap = cv2.VideoCapture(input_video.name)
41
  frames = []
42
  confidences = []
@@ -45,44 +75,18 @@ def predict_video(input_video):
45
  ret, frame = cap.read()
46
  if not ret:
47
  break
48
- frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
49
- frame_pil = Image.fromarray(frame)
50
 
51
- face = mtcnn(frame_pil)
52
- if face is None:
53
- raise Exception('No face detected')
54
- face = face.unsqueeze(0) # add the batch dimension
55
- face = F.interpolate(face, size=(256, 256), mode='bilinear', align_corners=False)
56
 
57
- face = face.to(DEVICE, dtype=torch.float32) / 255.0
58
-
59
- target_layers = [model.block8.branch1[-1]]
60
- use_cuda = True if torch.cuda.is_available() else False
61
- cam = GradCAM(model=model, target_layers=target_layers, use_cuda=use_cuda)
62
- targets = [ClassifierOutputTarget(0)]
63
-
64
- grayscale_cam = cam(input_tensor=face, targets=targets, eigen_smooth=True)
65
- grayscale_cam = grayscale_cam[0, :]
66
- visualization = show_cam_on_image(frame, grayscale_cam, use_rgb=True)
67
- face_with_mask = cv2.addWeighted(frame, 1, visualization, 0.5, 0)
68
-
69
- with torch.no_grad():
70
- output = torch.sigmoid(model(face).squeeze(0))
71
- prediction = "real" if output.item() < 0.5 else "fake"
72
-
73
- real_prediction = 1 - output.item()
74
- fake_prediction = output.item()
75
-
76
- confidences.append({
77
- 'real': real_prediction,
78
- 'fake': fake_prediction
79
- })
80
-
81
- frames.append(face_with_mask)
82
 
83
  cap.release()
84
 
85
- return confidences, frames
 
 
 
86
 
87
 
88
  # Gradio Interface
 
20
  select_largest=False,
21
  post_process=False,
22
  device=DEVICE
23
+ ).eval()
24
 
25
  model = InceptionResnetV1(
26
  pretrained="vggface2",
 
35
  model.eval()
36
 
37
  # Model Inference
38
def predict_frame(frame):
    """Classify the face found in one video frame and build a Grad-CAM overlay.

    Args:
        frame: Image array in OpenCV's BGR channel order; it is converted to
            RGB before face detection.

    Returns:
        A ``(prediction, face_with_mask)`` tuple. ``prediction`` is ``"real"``
        when the sigmoid of the model output is below 0.5 and ``"fake"``
        otherwise. ``face_with_mask`` is the 256x256 face crop blended with
        its Grad-CAM heatmap.

    Raises:
        Exception: If ``mtcnn`` returns no face for the frame.
    """
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    frame_pil = Image.fromarray(frame)

    face = mtcnn(frame_pil)
    if face is None:
        raise Exception('No face detected')
    face = face.unsqueeze(0)  # add the batch dimension
    face = F.interpolate(face, size=(256, 256), mode='bilinear', align_corners=False)

    # Scale pixel values to [0, 1] before feeding the classifier.
    face = face.to(DEVICE, dtype=torch.float32) / 255.0

    # Channel-last copy of the exact image the CAM is computed on.
    # show_cam_on_image rejects images whose maximum exceeds 1, and the CAM
    # has the face's 256x256 shape, so the raw full-size uint8 frame cannot
    # be used as the overlay background.
    face_rgb = face.squeeze(0).permute(1, 2, 0).contiguous().detach().cpu().numpy()
    face_uint8 = (face_rgb * 255).astype('uint8')

    target_layers = [model.block8.branch1[-1]]
    use_cuda = torch.cuda.is_available()
    cam = GradCAM(model=model, target_layers=target_layers, use_cuda=use_cuda)
    targets = [ClassifierOutputTarget(0)]

    # Grad-CAM needs gradients, so it runs outside torch.no_grad().
    grayscale_cam = cam(input_tensor=face, targets=targets, eigen_smooth=True)
    grayscale_cam = grayscale_cam[0, :]
    visualization = show_cam_on_image(face_rgb, grayscale_cam, use_rgb=True)
    face_with_mask = cv2.addWeighted(face_uint8, 1, visualization, 0.5, 0)

    with torch.no_grad():
        output = torch.sigmoid(model(face).squeeze(0))
        prediction = "real" if output.item() < 0.5 else "fake"

    return prediction, face_with_mask
66
+
67
+
68
+ # Function to process video
69
  def predict_video(input_video):
 
70
  cap = cv2.VideoCapture(input_video.name)
71
  frames = []
72
  confidences = []
 
75
  ret, frame = cap.read()
76
  if not ret:
77
  break
 
 
78
 
79
+ prediction, frame_with_mask = predict_frame(frame)
 
 
 
 
80
 
81
+ frames.append(frame_with_mask)
82
+ confidences.append(prediction)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
  cap.release()
85
 
86
+ # Determine the final prediction based on the maximum occurrence of predictions
87
+ final_prediction = max(set(confidences), key=confidences.count)
88
+
89
+ return final_prediction, frames
90
 
91
 
92
  # Gradio Interface