ybbwcwaps
commited on
Commit
•
9011f2d
1
Parent(s):
2fc658c
ui
Browse files
app.py
CHANGED
@@ -1,14 +1,27 @@
|
|
1 |
import gradio as gr
|
2 |
-
from run import detect_video
|
|
|
|
|
3 |
|
4 |
def greet(video):
|
5 |
print(video, type(video))
|
6 |
-
pred = detect_video(video_path=video)
|
7 |
if pred > 0.5:
|
8 |
string = f"Fake: {pred*100:.2f}%"
|
9 |
else:
|
10 |
string = f"Real: {(1-pred)*100:.2f}%"
|
|
|
11 |
return string
|
12 |
|
13 |
-
|
14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
from run import get_model, detect_video
|
3 |
+
|
4 |
+
model = get_model()
|
5 |
|
6 |
def greet(video):
|
7 |
print(video, type(video))
|
8 |
+
pred = detect_video(video_path=video, model=model)
|
9 |
if pred > 0.5:
|
10 |
string = f"Fake: {pred*100:.2f}%"
|
11 |
else:
|
12 |
string = f"Real: {(1-pred)*100:.2f}%"
|
13 |
+
print(string)
|
14 |
return string
|
15 |
|
16 |
+
with gr.Blocks() as demo:
|
17 |
+
gr.Markdown("# Fake Video Detector")
|
18 |
+
with gr.Tabs():
|
19 |
+
with gr.TabItem("Video Detect"):
|
20 |
+
with gr.Row():
|
21 |
+
video_input = gr.Video()
|
22 |
+
detect_output = gr.Textbox("Output")
|
23 |
+
video_button = gr.Button("detect")
|
24 |
+
|
25 |
+
video_button.click(greet, inputs=video_input, outputs=detect_output)
|
26 |
+
|
27 |
+
demo.launch()
|
run.py
CHANGED
@@ -19,24 +19,25 @@ import options
|
|
19 |
from networks.validator import Validator
|
20 |
|
21 |
|
22 |
-
def
|
23 |
val_opt = options.TestOptions().parse(print_options=False)
|
24 |
-
|
25 |
output_dir=os.path.join(val_opt.output, val_opt.name)
|
26 |
os.makedirs(output_dir, exist_ok=True)
|
|
|
27 |
# logger = create_logger(output_dir=output_dir, name="FakeVideoDetector")
|
28 |
print(f"working...")
|
29 |
|
30 |
model = Validator(val_opt)
|
31 |
model.load_state_dict(val_opt.ckpt)
|
32 |
print("ckpt loaded!")
|
|
|
|
|
33 |
|
34 |
-
|
35 |
frames, _, _ = read_video(str(video_path), pts_unit='sec')
|
36 |
frames = frames[:16]
|
37 |
frames = frames.permute(0, 3, 1, 2) # (T,H,W,C) -> (T,C,H,W)
|
38 |
|
39 |
-
|
40 |
video_frames = torch.cat([model.clip_model.preprocess(TF.to_pil_image(frame)).unsqueeze(0) for frame in frames])
|
41 |
|
42 |
with torch.no_grad():
|
@@ -73,7 +74,9 @@ if __name__ == '__main__':
|
|
73 |
|
74 |
# pred = model.model(model.input).view(-1).unsqueeze(1).sigmoid()
|
75 |
|
76 |
-
|
|
|
|
|
77 |
if pred > 0.5:
|
78 |
print(f"Fake: {pred*100:.2f}%")
|
79 |
else:
|
|
|
19 |
from networks.validator import Validator
|
20 |
|
21 |
|
22 |
+
def get_model():
|
23 |
val_opt = options.TestOptions().parse(print_options=False)
|
|
|
24 |
output_dir=os.path.join(val_opt.output, val_opt.name)
|
25 |
os.makedirs(output_dir, exist_ok=True)
|
26 |
+
|
27 |
# logger = create_logger(output_dir=output_dir, name="FakeVideoDetector")
|
28 |
print(f"working...")
|
29 |
|
30 |
model = Validator(val_opt)
|
31 |
model.load_state_dict(val_opt.ckpt)
|
32 |
print("ckpt loaded!")
|
33 |
+
return model
|
34 |
+
|
35 |
|
36 |
+
def detect_video(video_path, model):
|
37 |
frames, _, _ = read_video(str(video_path), pts_unit='sec')
|
38 |
frames = frames[:16]
|
39 |
frames = frames.permute(0, 3, 1, 2) # (T,H,W,C) -> (T,C,H,W)
|
40 |
|
|
|
41 |
video_frames = torch.cat([model.clip_model.preprocess(TF.to_pil_image(frame)).unsqueeze(0) for frame in frames])
|
42 |
|
43 |
with torch.no_grad():
|
|
|
74 |
|
75 |
# pred = model.model(model.input).view(-1).unsqueeze(1).sigmoid()
|
76 |
|
77 |
+
model = get_model()
|
78 |
+
|
79 |
+
pred = detect_video(video_path, model)
|
80 |
if pred > 0.5:
|
81 |
print(f"Fake: {pred*100:.2f}%")
|
82 |
else:
|