ybbwcwaps committed on
Commit
9011f2d
1 Parent(s): 2fc658c
Files changed (2) hide show
  1. app.py +17 -4
  2. run.py +8 -5
app.py CHANGED
@@ -1,14 +1,27 @@
1
  import gradio as gr
2
- from run import detect_video
 
 
3
 
4
def greet(video):
    """Run the deepfake detector on an uploaded video and format the verdict."""
    print(video, type(video))
    pred = detect_video(video_path=video)
    # Scores above 0.5 mean "fake"; report the confidence of the winning class.
    if pred > 0.5:
        return f"Fake: {pred*100:.2f}%"
    return f"Real: {(1-pred)*100:.2f}%"
12
 
13
- demo = gr.Interface(fn=greet, inputs=gr.Video(), outputs="text")
14
- demo.launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from run import get_model, detect_video
3
+
4
+ model = get_model()
5
 
6
def greet(video):
    """Classify an uploaded video as real or fake.

    Args:
        video: Path to the uploaded video file (gradio passes a str path,
            or None when the user submitted nothing).

    Returns:
        A human-readable verdict string with the model's confidence.
    """
    # Guard: gradio hands us None when the detect button is clicked with no
    # upload; fail gracefully instead of crashing inside detect_video.
    if video is None:
        return "No video provided."
    print(video, type(video))
    pred = detect_video(video_path=video, model=model)
    # Scores above 0.5 mean "fake"; report the confidence of the winning class.
    if pred > 0.5:
        string = f"Fake: {pred*100:.2f}%"
    else:
        string = f"Real: {(1-pred)*100:.2f}%"
    print(string)
    return string
15
 
16
# Build the gradio UI: a single tab holding the video upload widget,
# a verdict textbox, and a button wired to the greet() classifier.
with gr.Blocks() as demo:
    gr.Markdown("# Fake Video Detector")
    with gr.Tabs():
        with gr.TabItem("Video Detect"):
            with gr.Row():
                uploaded_video = gr.Video()
                verdict_box = gr.Textbox("Output")
            run_button = gr.Button("detect")
            run_button.click(greet, inputs=uploaded_video, outputs=verdict_box)

demo.launch()
run.py CHANGED
@@ -19,24 +19,25 @@ import options
19
  from networks.validator import Validator
20
 
21
 
22
- def detect_video(video_path):
23
  val_opt = options.TestOptions().parse(print_options=False)
24
-
25
  output_dir=os.path.join(val_opt.output, val_opt.name)
26
  os.makedirs(output_dir, exist_ok=True)
 
27
  # logger = create_logger(output_dir=output_dir, name="FakeVideoDetector")
28
  print(f"working...")
29
 
30
  model = Validator(val_opt)
31
  model.load_state_dict(val_opt.ckpt)
32
  print("ckpt loaded!")
 
 
33
 
34
- # val_loader = create_test_dataloader(val_opt, clip_model = None, transform = model.clip_model.preprocess)
35
  frames, _, _ = read_video(str(video_path), pts_unit='sec')
36
  frames = frames[:16]
37
  frames = frames.permute(0, 3, 1, 2) # (T,H,W,C) -> (T,C,H,W)
38
 
39
-
40
  video_frames = torch.cat([model.clip_model.preprocess(TF.to_pil_image(frame)).unsqueeze(0) for frame in frames])
41
 
42
  with torch.no_grad():
@@ -73,7 +74,9 @@ if __name__ == '__main__':
73
 
74
  # pred = model.model(model.input).view(-1).unsqueeze(1).sigmoid()
75
 
76
- pred = detect_video(video_path)
 
 
77
  if pred > 0.5:
78
  print(f"Fake: {pred*100:.2f}%")
79
  else:
 
19
  from networks.validator import Validator
20
 
21
 
22
def get_model():
    """Parse test-time options, build the Validator, and load its checkpoint.

    Returns:
        The Validator instance with weights loaded from opt.ckpt.
    """
    opt = options.TestOptions().parse(print_options=False)
    # The output directory is created up front (used by the optional logger below).
    out_dir = os.path.join(opt.output, opt.name)
    os.makedirs(out_dir, exist_ok=True)
    # logger = create_logger(output_dir=out_dir, name="FakeVideoDetector")
    print(f"working...")
    detector = Validator(opt)
    detector.load_state_dict(opt.ckpt)
    print("ckpt loaded!")
    return detector
34
+
35
 
36
+ def detect_video(video_path, model):
37
  frames, _, _ = read_video(str(video_path), pts_unit='sec')
38
  frames = frames[:16]
39
  frames = frames.permute(0, 3, 1, 2) # (T,H,W,C) -> (T,C,H,W)
40
 
 
41
  video_frames = torch.cat([model.clip_model.preprocess(TF.to_pil_image(frame)).unsqueeze(0) for frame in frames])
42
 
43
  with torch.no_grad():
 
74
 
75
  # pred = model.model(model.input).view(-1).unsqueeze(1).sigmoid()
76
 
77
+ model = get_model()
78
+
79
+ pred = detect_video(video_path, model)
80
  if pred > 0.5:
81
  print(f"Fake: {pred*100:.2f}%")
82
  else: