|
import os |
|
import torch |
|
import torchvision.transforms as transforms |
|
import torchvision.transforms.functional as TF |
|
from torchvision.io import read_video |
|
import torch.utils.data |
|
import numpy as np |
|
from sklearn.metrics import average_precision_score, precision_recall_curve, accuracy_score |
|
import pickle |
|
from tqdm import tqdm |
|
from datetime import datetime |
|
from copy import deepcopy |
|
from dataset_paths import DATASET_PATHS |
|
import random |
|
|
|
from datasets import create_test_dataloader |
|
from utils.logger import create_logger |
|
import options |
|
from networks.validator import Validator |
|
|
|
|
|
def get_model():
    """Build a ``Validator`` from the parsed test options and load its checkpoint.

    The test options are parsed without echoing them, the run directory
    ``<output>/<name>`` is created if it does not exist yet, and a
    ``Validator`` is constructed from those options before
    ``load_state_dict`` is called with the ``ckpt`` option.

    Returns:
        The ``Validator`` instance after ``load_state_dict`` has been called.
    """
    opts = options.TestOptions().parse(print_options=False)

    # Only the directory-creation side effect is needed here; nothing in this
    # function writes into the directory afterwards.
    os.makedirs(os.path.join(opts.output, opts.name), exist_ok=True)

    print("working...")

    validator = Validator(opts)
    # NOTE(review): `opts.ckpt` is passed straight to `load_state_dict`;
    # presumably Validator overrides it to accept a checkpoint path -- confirm.
    validator.load_state_dict(opts.ckpt)
    print("ckpt loaded!")

    return validator
|
|
|
|
|
def detect_video(video_path, model):
    """Score a video with ``model`` and return the first prediction as a float.

    At most the first 16 decoded frames are used. Each frame is converted to a
    PIL image, run through ``model.clip_model.preprocess`` and stacked into a
    single batch. The batch is handed to the model through ``model.set_input``
    together with a placeholder label tensor ``[0]``, then ``model.model`` is
    evaluated on ``model.input``. The output is flattened and passed through a
    sigmoid.

    Args:
        video_path: Location of the video file; converted with ``str()``
            before being passed to ``read_video``.
        model: Object exposing ``clip_model.preprocess``, ``set_input``,
            ``input`` and ``model``.

    Returns:
        float: The first element of the sigmoid-activated, flattened model
        output.

    Raises:
        RuntimeError: If no frames could be decoded from ``video_path``.
    """
    max_frames = 16

    frames, _, _ = read_video(str(video_path), pts_unit='sec')
    frames = frames[:max_frames]

    # Fail early with the offending path instead of letting the empty batch
    # surface later as an opaque tensor-concatenation error.
    if frames.shape[0] == 0:
        raise RuntimeError(
            f"no video frames could be decoded from {str(video_path)!r}"
        )

    # read_video returns frames channel-last by default; to_pil_image expects
    # channel-first tensors.
    frames = frames.permute(0, 3, 1, 2)

    video_frames = torch.stack(
        [model.clip_model.preprocess(TF.to_pil_image(frame)) for frame in frames]
    )

    # Inference only: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        # The second entry is a placeholder label.
        # NOTE(review): presumably set_input expects an (input, label) pair -- confirm.
        model.set_input([video_frames, torch.tensor([0])])
        pred = model.model(model.input).view(-1).unsqueeze(1).sigmoid()

    return pred[0].item()
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: score one hard-coded sample clip and print a verdict.
    video_path = '../../dataset/MSRVTT/videos/all/video1.mp4'

    detector = get_model()
    pred = detect_video(video_path, detector)

    # Scores above 0.5 are reported as fake with the score itself as the
    # percentage; everything else is reported as real with the complement.
    verdict = (
        f"Fake: {pred*100:.2f}%"
        if pred > 0.5
        else f"Real: {(1-pred)*100:.2f}%"
    )
    print(verdict)
|
|