SixOpen commited on
Commit
d72b356
1 Parent(s): ed84fba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -11
app.py CHANGED
@@ -12,15 +12,14 @@ from Scripts.model import create_cam, create_model
12
  from Scripts.preprocess import crop_face, extract_face, extract_frames
13
  from Scripts.ca_generator import get_augs
14
 
15
- warnings.filterwarnings('ignore')
16
-
17
 
 
18
 
19
- device = torch.device('cpu')
20
 
21
  sbcl = create_model("Weights/weights.tar")
22
 
23
-
24
  face_detector = get_model("resnet50_2020-07-20", max_size=1024, device=device)
25
  face_detector.eval()
26
 
@@ -37,9 +36,8 @@ dlib_face_detector = dlib.get_frontal_face_detector()
37
  dlib_face_predictor = dlib.shape_predictor(
38
  'Weights/shape_predictor_81_face_landmarks.dat')
39
 
40
-
41
  def predict_image(inp):
42
-
43
  face_list = extract_face(inp, face_detector)
44
 
45
  if len(face_list) == 0:
@@ -56,9 +54,8 @@ def predict_image(inp):
56
 
57
  return confidences, cam_image
58
 
59
-
60
  def predict_video(inp):
61
-
62
  face_list, idx_list = extract_frames(inp, 10, face_detector)
63
 
64
  with torch.no_grad():
@@ -84,8 +81,6 @@ def predict_video(inp):
84
 
85
  return {'Real': 1-pred, 'Fake': pred}, cam_image
86
 
87
-
88
-
89
  with gr.Blocks(title="Deepfake Detection CL", theme='upsatwal/mlsc_tiet', css="""
90
  @import url('https://fonts.googleapis.com/css?family=Source+Code+Pro:200');
91
  #custom_header {
@@ -186,4 +181,4 @@ with gr.Blocks(title="Deepfake Detection CL", theme='upsatwal/mlsc_tiet', css=""
186
  btn_video.click(predict_video, inputs=input_video, outputs=[label_probs_video, output_image_video], api_name="/predict_video")
187
 
188
  if __name__ == "__main__":
189
- demo.launch()
 
12
  from Scripts.preprocess import crop_face, extract_face, extract_frames
13
  from Scripts.ca_generator import get_augs
14
 
15
+ import spaces
 
16
 
17
+ warnings.filterwarnings('ignore')
18
 
19
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
20
 
21
  sbcl = create_model("Weights/weights.tar")
22
 
 
23
  face_detector = get_model("resnet50_2020-07-20", max_size=1024, device=device)
24
  face_detector.eval()
25
 
 
36
  dlib_face_predictor = dlib.shape_predictor(
37
  'Weights/shape_predictor_81_face_landmarks.dat')
38
 
39
+ @spaces.GPU
40
  def predict_image(inp):
 
41
  face_list = extract_face(inp, face_detector)
42
 
43
  if len(face_list) == 0:
 
54
 
55
  return confidences, cam_image
56
 
57
+ @spaces.GPU
58
  def predict_video(inp):
 
59
  face_list, idx_list = extract_frames(inp, 10, face_detector)
60
 
61
  with torch.no_grad():
 
81
 
82
  return {'Real': 1-pred, 'Fake': pred}, cam_image
83
 
 
 
84
  with gr.Blocks(title="Deepfake Detection CL", theme='upsatwal/mlsc_tiet', css="""
85
  @import url('https://fonts.googleapis.com/css?family=Source+Code+Pro:200');
86
  #custom_header {
 
181
  btn_video.click(predict_video, inputs=input_video, outputs=[label_probs_video, output_image_video], api_name="/predict_video")
182
 
183
  if __name__ == "__main__":
184
+ demo.launch()