PraneshJs committed
Commit 6d0c2aa · verified · 1 Parent(s): 2b0e2be

Update inference_2.py

Files changed (1): inference_2.py (+64, -59)
inference_2.py CHANGED
@@ -3,22 +3,19 @@ import cv2
 import onnx
 import torch
 import numpy as np
-from types import SimpleNamespace
-from onnx2pytorch import ConvertModel
+import argparse
 from models.TMC import ETMC
 from models import image
+from onnx2pytorch import ConvertModel
+import types
 
-# -----------------------------
-# Load ONNX -> PyTorch safely
-# -----------------------------
+# Load ONNX model and convert to PyTorch
 onnx_model = onnx.load('checkpoints/efficientnet.onnx')
 pytorch_model = ConvertModel(onnx_model)
 
 torch.manual_seed(42)
 
-# -----------------------------
-# Audio model arguments
-# -----------------------------
+# Audio model parameters
 audio_args = {
     'nb_samp': 64600,
     'first_conv': 1024,
@@ -28,53 +25,49 @@ audio_args = {
     'nb_fc_node': 1024,
     'gru_node': 1024,
     'nb_gru_layer': 3,
-    'nb_classes': 2,
+    'nb_classes': 2
+}
+
+# Create a complete args object for RawNet
+audio_args_complete = {
+    **audio_args,
+    'pretrained_audio_encoder': False,
+    'freeze_audio_encoder': False,
     'device': 'cpu'
 }
+audio_args_obj = types.SimpleNamespace(**audio_args_complete)
+
+# Load models
+spec_model = image.RawNet(audio_args_obj)
+spec_model_ckpt = torch.load('checkpoints/model.pth', map_location='cpu')
+spec_model.load_state_dict(spec_model_ckpt['spec_encoder'], strict=True)
+spec_model.eval()
+
+img_model = pytorch_model
+img_model_ckpt = torch.load('checkpoints/model.pth', map_location='cpu')
+img_model.load_state_dict(img_model_ckpt['rgb_encoder'], strict=True)
+img_model.eval()
+
 
-audio_args_obj = SimpleNamespace(**audio_args)
-
-# -----------------------------
-# Load Audio Model
-# -----------------------------
-def load_audio_model():
-    spec_model = image.RawNet(audio_args_obj)
-    ckpt = torch.load('checkpoints/model.pth', map_location='cpu')
-    spec_model.load_state_dict(ckpt['spec_encoder'], strict=True)
-    spec_model.eval()
-    return spec_model
-
-spec_model = load_audio_model()
-
-# -----------------------------
-# Load Image Model
-# -----------------------------
-def load_image_model():
-    rgb_encoder = pytorch_model
-    ckpt = torch.load('checkpoints/model.pth', map_location='cpu')
-    rgb_encoder.load_state_dict(ckpt['rgb_encoder'], strict=True)
-    rgb_encoder.eval()
-    return rgb_encoder
-
-img_model = load_image_model()
-
-# -----------------------------
 # Preprocessing functions
-# -----------------------------
 def preprocess_img(face):
     face = face / 255.0
     face = cv2.resize(face, (256, 256))
     face_pt = torch.unsqueeze(torch.Tensor(face), dim=0)
     return face_pt
 
+
 def preprocess_audio(audio_file):
     audio_pt = torch.unsqueeze(torch.Tensor(audio_file), dim=0)
     return audio_pt
 
+
 def preprocess_video(input_video, n_frames=3):
     v_cap = cv2.VideoCapture(input_video)
     v_len = int(v_cap.get(cv2.CAP_PROP_FRAME_COUNT))
-    sample = np.linspace(0, v_len-1, n_frames).astype(int)
+
+    sample = np.linspace(0, v_len - 1, n_frames).astype(int)
+
     frames = []
     for j in range(v_len):
         success = v_cap.grab()
@@ -88,40 +81,52 @@ def preprocess_video(input_video, n_frames=3):
     v_cap.release()
     return frames
 
-# -----------------------------
-# Inference functions
-# -----------------------------
+
+# Prediction functions
 def deepfakes_spec_predict(input_audio):
-    audio = preprocess_audio(input_audio)
-    spec_grads = spec_model.forward(audio)
-    spec_grads_np = np.exp(spec_grads.cpu().detach().numpy().squeeze())
-    max_value = np.argmax(spec_grads_np)
-    if max_value > 0.5:
+    x, _ = input_audio
+    audio = preprocess_audio(x)
+    spec_grads = spec_model(audio)
+    spec_grads_np = np.squeeze(spec_grads.detach().cpu().numpy())
+
+    if spec_grads_np[0] > 0.5:
         text2 = f"The audio is REAL."
     else:
         text2 = f"The audio is FAKE."
     return text2
 
+
 def deepfakes_image_predict(input_image):
     face = preprocess_img(input_image)
-    img_grads = img_model.forward(face).cpu().detach().numpy().squeeze()
-    if img_grads[0] > 0.5:
-        text2 = f"The image is REAL. Confidence: {img_grads[0]*100:.3f}%"
+    img_grads = img_model(face)
+    img_grads_np = np.squeeze(img_grads.detach().cpu().numpy())
+
+    if img_grads_np[0] > 0.5:
+        preds = round(img_grads_np[0] * 100, 3)
+        text2 = f"The image is REAL. \nConfidence score: {preds}%"
     else:
-        text2 = f"The image is FAKE. Confidence: {img_grads[1]*100:.3f}%"
+        preds = round(img_grads_np[1] * 100, 3)
+        text2 = f"The image is FAKE. \nConfidence score: {preds}%"
     return text2
 
+
 def deepfakes_video_predict(input_video):
     video_frames = preprocess_video(input_video)
-    real_list, fake_list = [], []
+    real_faces_list, fake_faces_list = [], []
+
    for face in video_frames:
-        img_grads = img_model.forward(face).cpu().detach().numpy().squeeze()
-        real_list.append(img_grads[0])
-        fake_list.append(img_grads[1])
-    real_mean = np.mean(real_list)
-    fake_mean = np.mean(fake_list)
-    if real_mean > 0.5:
-        text2 = f"The video is REAL. Confidence: {real_mean*100:.3f}%"
+        img_grads = img_model(face)
+        img_grads_np = np.squeeze(img_grads.detach().cpu().numpy())
+        real_faces_list.append(img_grads_np[0])
+        fake_faces_list.append(img_grads_np[1])
+
+    real_faces_mean = np.mean(real_faces_list)
+    fake_faces_mean = np.mean(fake_faces_list)
+
+    if real_faces_mean > 0.5:
+        preds = round(real_faces_mean * 100, 3)
+        text2 = f"The video is REAL. \nConfidence score: {preds}%"
     else:
-        text2 = f"The video is FAKE. Confidence: {fake_mean*100:.3f}%"
+        preds = round(fake_faces_mean * 100, 3)
+        text2 = f"The video is FAKE. \nConfidence score: {preds}%"
     return text2
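
After this change, the module loads both encoders at import time and exposes three prediction entry points. A minimal sketch of how they might be called; the sample paths are hypothetical, librosa is an assumed audio loader (the commit shows none), and the (waveform, sample_rate) tuple for deepfakes_spec_predict is inferred from the new "x, _ = input_audio" unpacking:

import cv2
import librosa  # assumed loader; not part of this commit

from inference_2 import (
    deepfakes_image_predict,
    deepfakes_spec_predict,
    deepfakes_video_predict,
)

# Image: preprocess_img() rescales and resizes internally, so a raw
# uint8 frame from cv2.imread() is enough.
frame = cv2.imread('samples/face.jpg')  # hypothetical path
print(deepfakes_image_predict(frame))

# Video: preprocess_video() samples n_frames frames from the clip.
print(deepfakes_video_predict('samples/clip.mp4'))  # hypothetical path

# Audio: deepfakes_spec_predict() unpacks a 2-tuple and feeds only the
# first element to the model, so a (waveform, sample_rate) pair as
# returned by librosa.load() is one plausible input shape.
waveform, sr = librosa.load('samples/voice.wav', sr=None)  # hypothetical path
print(deepfakes_spec_predict((waveform, sr)))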
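Note that all three predictors now compare the raw index-0 output against 0.5 (the old audio path applied np.exp and argmax first), which implicitly assumes the encoders emit probabilities, with index 0 meaning "real" and index 1 meaning "fake". If the checkpoints actually emit logits, a softmax would be needed before thresholding; a small sketch of that variant, purely an assumption about the model heads:

import torch
import torch.nn.functional as F

def to_probs(outputs: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: map raw 2-class outputs (logits) to
    # probabilities so outputs[0] ("real") and outputs[1] ("fake")
    # sum to 1 and the 0.5 threshold is meaningful.
    return F.softmax(outputs.squeeze(), dim=-1)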