PraneshJs committed
Commit 996bc38 · verified · 1 Parent(s): 6d0c2aa

Update inference_2.py

Files changed (1): inference_2.py (+48 -68)
inference_2.py CHANGED
@@ -1,21 +1,15 @@
 import os
 import cv2
-import onnx
 import torch
 import numpy as np
-import argparse
-from models.TMC import ETMC
-from models import image
+from onnx import load as onnx_load
 from onnx2pytorch import ConvertModel
-import types
-
-# Load ONNX model and convert to PyTorch
-onnx_model = onnx.load('checkpoints/efficientnet.onnx')
-pytorch_model = ConvertModel(onnx_model)
+from models import image  # Your RawNet audio model
 
+# Set seed for reproducibility
 torch.manual_seed(42)
 
-# Audio model parameters
+# Audio args for RawNet
 audio_args = {
     'nb_samp': 64600,
     'first_conv': 1024,
@@ -25,50 +19,45 @@ audio_args = {
     'nb_fc_node': 1024,
     'gru_node': 1024,
     'nb_gru_layer': 3,
-    'nb_classes': 2
+    'nb_classes': 2,
+    'device': 'cpu',
+    'pretrained_audio_encoder': False
 }
 
-# Create a complete args object for RawNet
-audio_args_complete = {
-    **audio_args,
-    'pretrained_audio_encoder': False,
-    'freeze_audio_encoder': False,
-    'device': 'cpu'
-}
-audio_args_obj = types.SimpleNamespace(**audio_args_complete)
+# Convert audio_args dict to a namespace object
+from types import SimpleNamespace
+audio_args_obj = SimpleNamespace(**audio_args)
 
-# Load models
+# Load ONNX → PyTorch model for images
+onnx_model = onnx_load("checkpoints/efficientnet.onnx")
+img_model = ConvertModel(onnx_model)  # do NOT use strict=True (not supported)
+
+# Load Audio model
 spec_model = image.RawNet(audio_args_obj)
-spec_model_ckpt = torch.load('checkpoints/model.pth', map_location='cpu')
-spec_model.load_state_dict(spec_model_ckpt['spec_encoder'], strict=True)
-spec_model.eval()
 
-img_model = pytorch_model
-img_model_ckpt = torch.load('checkpoints/model.pth', map_location='cpu')
-img_model.load_state_dict(img_model_ckpt['rgb_encoder'], strict=True)
+# Ensure models are in eval mode
 img_model.eval()
+spec_model.eval()
 
-
+# -------------------------
 # Preprocessing functions
+# -------------------------
 def preprocess_img(face):
     face = face / 255.0
     face = cv2.resize(face, (256, 256))
-    face_pt = torch.unsqueeze(torch.Tensor(face), dim=0)
-    return face_pt
-
+    face_tensor = torch.unsqueeze(torch.Tensor(face), dim=0)
+    return face_tensor
 
 def preprocess_audio(audio_file):
-    audio_pt = torch.unsqueeze(torch.Tensor(audio_file), dim=0)
-    return audio_pt
-
+    audio_tensor = torch.unsqueeze(torch.Tensor(audio_file), dim=0)
+    return audio_tensor
 
 def preprocess_video(input_video, n_frames=3):
     v_cap = cv2.VideoCapture(input_video)
     v_len = int(v_cap.get(cv2.CAP_PROP_FRAME_COUNT))
-
     sample = np.linspace(0, v_len - 1, n_frames).astype(int)
-
     frames = []
+
     for j in range(v_len):
         success = v_cap.grab()
         if j in sample:
@@ -81,52 +70,43 @@ def preprocess_video(input_video, n_frames=3):
     v_cap.release()
     return frames
 
-
+# -------------------------
 # Prediction functions
+# -------------------------
 def deepfakes_spec_predict(input_audio):
-    x, _ = input_audio
-    audio = preprocess_audio(x)
-    spec_grads = spec_model(audio)
-    spec_grads_np = np.squeeze(spec_grads.detach().cpu().numpy())
+    audio_tensor = preprocess_audio(input_audio)
+    spec_grads = spec_model.forward(audio_tensor)
+    spec_grads_np = np.squeeze(spec_grads.cpu().detach().numpy())
 
     if spec_grads_np[0] > 0.5:
-        text2 = f"The audio is REAL."
+        return "The audio is REAL."
     else:
-        text2 = f"The audio is FAKE."
-    return text2
-
+        return "The audio is FAKE."
 
 def deepfakes_image_predict(input_image):
-    face = preprocess_img(input_image)
-    img_grads = img_model(face)
-    img_grads_np = np.squeeze(img_grads.detach().cpu().numpy())
+    face_tensor = preprocess_img(input_image)
+    img_grads = img_model.forward(face_tensor)
+    img_grads_np = np.squeeze(img_grads.cpu().detach().numpy())
 
     if img_grads_np[0] > 0.5:
-        preds = round(img_grads_np[0] * 100, 3)
-        text2 = f"The image is REAL. \nConfidence score: {preds}%"
+        return f"The image is REAL. Confidence score: {round(img_grads_np[0]*100,2)}%"
     else:
-        preds = round(img_grads_np[1] * 100, 3)
-        text2 = f"The image is FAKE. \nConfidence score: {preds}%"
-    return text2
-
+        return f"The image is FAKE. Confidence score: {round(img_grads_np[1]*100,2)}%"
 
 def deepfakes_video_predict(input_video):
-    video_frames = preprocess_video(input_video)
-    real_faces_list, fake_faces_list = [], []
+    frames = preprocess_video(input_video)
+    real_list, fake_list = [], []
 
-    for face in video_frames:
-        img_grads = img_model(face)
-        img_grads_np = np.squeeze(img_grads.detach().cpu().numpy())
-        real_faces_list.append(img_grads_np[0])
-        fake_faces_list.append(img_grads_np[1])
+    for frame in frames:
+        img_grads = img_model.forward(frame)
+        img_grads_np = np.squeeze(img_grads.cpu().detach().numpy())
+        real_list.append(img_grads_np[0])
+        fake_list.append(img_grads_np[1])
 
-    real_faces_mean = np.mean(real_faces_list)
-    fake_faces_mean = np.mean(fake_faces_list)
+    real_mean = np.mean(real_list)
+    fake_mean = np.mean(fake_list)
 
-    if real_faces_mean > 0.5:
-        preds = round(real_faces_mean * 100, 3)
-        text2 = f"The video is REAL. \nConfidence score: {preds}%"
+    if real_mean > 0.5:
+        return f"The video is REAL. Confidence: {round(real_mean*100,2)}%"
     else:
-        preds = round(fake_faces_mean * 100, 3)
-        text2 = f"The video is FAKE. \nConfidence score: {preds}%"
-    return text2
+        return f"The video is FAKE. Confidence: {round(fake_mean*100,2)}%"
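Since this rewrite loads the image model straight from the ONNX graph instead of restoring weights from a model.pth checkpoint, a quick smoke test of the converted model can catch shape problems early. The following is a minimal sketch, assuming the exported EfficientNet accepts the (1, 256, 256, 3) tensors that preprocess_img produces and emits two scores (real at index 0, fake at index 1); if the graph was actually exported channels-first, a permute would be needed:

import onnx
import torch
from onnx2pytorch import ConvertModel

# Convert the ONNX graph once, as inference_2.py now does.
model = ConvertModel(onnx.load("checkpoints/efficientnet.onnx"))
model.eval()

# Dummy input shaped like preprocess_img's output; this layout is an assumption.
dummy = torch.rand(1, 256, 256, 3)
with torch.no_grad():
    out = model(dummy)

# The prediction code squeezes the output and indexes [0] (real) and [1] (fake).
print(out.shape)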
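And a minimal driver sketch for the three entry points after this change. The file paths are hypothetical, and it assumes input_audio is now a bare 1-D waveform (64600 samples, matching audio_args['nb_samp']) rather than the (waveform, sample_rate) tuple the removed code unpacked:

import cv2
import numpy as np

from inference_2 import (
    deepfakes_image_predict,
    deepfakes_spec_predict,
    deepfakes_video_predict,
)

# Image: preprocess_img expects a raw H x W x 3 array (cv2.imread returns BGR).
face = cv2.imread("samples/face.jpg")  # hypothetical path
print(deepfakes_image_predict(face))

# Audio: a bare waveform; random noise stands in for a real 64600-sample clip.
waveform = np.random.randn(64600).astype(np.float32)
print(deepfakes_spec_predict(waveform))

# Video: a path string; frames are sampled inside preprocess_video.
print(deepfakes_video_predict("samples/clip.mp4"))  # hypothetical path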