# NOTE: web-page residue captured with this file ("Spaces: Running Running");
# it is not part of the module.
import os | |
import numpy as np | |
import cv2 | |
import torch | |
from torchvision import transforms | |
from tqdm import tqdm | |
from dataset.loader import normalize_data | |
from .config import load_config | |
from .genconvit import GenConViT | |
from decord import VideoReader, cpu | |
from .face_detection import detect_faces # Import our new function | |
# Torch device string shared by every helper in this module: "cuda" when
# torch reports a usable CUDA runtime, "cpu" otherwise.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
def load_genconvit(config, net, ed_weight, vae_weight, fp16):
    """Build a GenConViT model and prepare it for inference.

    The model is constructed, moved to the module-level ``device``, switched
    to evaluation mode and, when ``fp16`` is truthy, converted with
    ``half()``.

    Args:
        config: Forwarded unchanged as the first argument of ``GenConViT``.
        net: Forwarded as the ``net`` keyword of ``GenConViT``.
        ed_weight: Forwarded as the ``ed`` keyword of ``GenConViT``.
        vae_weight: Forwarded as the ``vae`` keyword of ``GenConViT``.
        fp16: Forwarded as the ``fp16`` keyword and also decides whether
            ``half()`` is called on the finished model.

    Returns:
        The constructed model.
    """
    build_kwargs = {
        "ed": ed_weight,
        "vae": vae_weight,
        "net": net,
        "fp16": fp16,
    }
    network = GenConViT(config, **build_kwargs)
    network.to(device)
    network.eval()
    if not fp16:
        return network
    network.half()
    return network
# face_rec keeps its historical signature but delegates to detect_faces.
def face_rec(frames, p=None, klass=None):
    """Run face detection on ``frames`` via ``detect_faces``.

    Args:
        frames: Passed straight through to ``detect_faces``.
        p: Accepted but unused.
        klass: Accepted but unused.

    Returns:
        Whatever ``detect_faces(frames)`` returns, unchanged.
    """
    detection = detect_faces(frames)
    return detection
def preprocess_frame(frame):
    """Convert a batch of frames into a normalized float tensor on ``device``.

    The input is turned into a float tensor, axis 3 is moved to position 1
    (the permutation ``(0, 3, 1, 2)`` requires a 4-D input), and every item
    along axis 0 is divided by 255 and passed through the ``"vid"`` transform
    returned by ``normalize_data()``.

    Args:
        frame: Array-like accepted by ``torch.tensor``; presumably a stack of
            images laid out as (frames, height, width, channels) — confirm
            against ``detect_faces``.

    Returns:
        The permuted, normalized tensor.
    """
    # Build the transform once instead of once per frame; this assumes
    # normalize_data() returns an equivalent transform on every call.
    normalize_vid = normalize_data()["vid"]

    batch = torch.tensor(frame, device=device).float()
    batch = batch.permute((0, 3, 1, 2))
    for idx, image in enumerate(batch):
        # `image / 255.0` allocates a new tensor, so writing the result back
        # into `batch` does not disturb the value being read.
        batch[idx] = normalize_vid(image / 255.0)
    return batch
def pred_vid(df, model):
    """Score a batch with ``model`` and summarize it as one prediction.

    The model output is squeezed, passed through a sigmoid, and reduced by
    ``max_prediction_value``; everything runs with gradients disabled.

    Args:
        df: Input handed directly to ``model``.
        model: Callable returning a tensor.

    Returns:
        The tuple produced by ``max_prediction_value``.
    """
    with torch.no_grad():
        raw_output = model(df)
        scores = torch.sigmoid(raw_output.squeeze())
        return max_prediction_value(scores)
def max_prediction_value(y_pred):
    """Average scores over axis 0 and report the winning index and a score.

    Args:
        y_pred: Tensor of scores with two entries per row. A 1-D tensor is
            treated as a single row.

    Returns:
        A ``(index, value)`` tuple of Python scalars. ``index`` is the argmax
        of the averaged scores. ``value`` is the averaged entry 0 when it is
        greater than entry 1, otherwise ``abs(1 - entry 1)``.
    """
    # A batch of one arrives as a flat tensor because pred_vid squeezes the
    # model output. Without this, the mean over dim 0 would collapse to a
    # 0-d tensor and indexing it below would raise IndexError.
    if y_pred.dim() == 1:
        y_pred = y_pred.unsqueeze(0)

    mean_val = torch.mean(y_pred, dim=0)
    first, second = mean_val[0], mean_val[1]
    if first > second:
        value = first.item()
    else:
        value = abs(1 - second).item()
    return torch.argmax(mean_val).item(), value
def real_or_fake(prediction):
    """Translate a predicted index into its text label.

    The index is XOR-ed with 1 before the lookup, so ``0`` yields ``"FAKE"``
    and ``1`` yields ``"REAL"``.

    Args:
        prediction: Integer index; values other than 0 or 1 raise ``KeyError``.

    Returns:
        ``"REAL"`` or ``"FAKE"``.
    """
    labels_by_flipped_index = {0: "REAL", 1: "FAKE"}
    flipped = prediction ^ 1
    return labels_by_flipped_index[flipped]
def extract_frames(video_file, frames_nums=15):
    """Read up to ``frames_nums`` evenly strided frames from a video.

    Args:
        video_file: Source handed to ``VideoReader``.
        frames_nums: Maximum number of frames to return. It also sets the
            stride, ``max(1, total_frames // frames_nums)``.

    Returns:
        The result of ``asnumpy()`` on the fetched batch of frames.
    """
    reader = VideoReader(video_file, ctx=cpu(0))
    total_frames = len(reader)
    # Never let the stride fall to zero when the video is shorter than the
    # requested frame count.
    stride = max(1, total_frames // frames_nums)
    wanted = list(range(0, total_frames, stride)[:frames_nums])
    batch = reader.get_batch(wanted)
    return batch.asnumpy()
def df_face(vid, num_frames, net):
    """Extract frames from a video, detect faces and preprocess them.

    Args:
        vid: Video source passed to ``extract_frames``.
        num_frames: Frame budget passed to ``extract_frames``.
        net: Accepted but unused.

    Returns:
        The output of ``preprocess_frame`` when the detected count is
        positive, otherwise an empty list.
    """
    frames = extract_frames(vid, num_frames)
    faces, found = face_rec(frames)
    if found > 0:
        return preprocess_frame(faces)
    return []
def is_video(vid):
    """Tell whether ``vid`` is an existing file with a known video extension.

    The extension comparison ignores case, so ``clip.MP4`` is accepted just
    like ``clip.mp4``. The existence check is printed as a diagnostic line.

    Args:
        vid: Path string to inspect.

    Returns:
        ``True`` when the path is a regular file whose name ends in one of
        ``.avi``, ``.mp4``, ``.mpg``, ``.mpeg`` or ``.mov``; ``False``
        otherwise.
    """
    video_suffixes = (".avi", ".mp4", ".mpg", ".mpeg", ".mov")
    # Check the filesystem once and reuse the answer for both the diagnostic
    # line and the result.
    exists = os.path.isfile(vid)
    print('IS FILE', exists)
    return exists and vid.lower().endswith(video_suffixes)
def set_result():
    """Create an empty results container.

    Returns:
        A dict with a single ``"video"`` entry that maps each of ``"name"``,
        ``"pred"``, ``"klass"``, ``"pred_label"`` and ``"correct_label"`` to
        its own fresh empty list.
    """
    columns = ("name", "pred", "klass", "pred_label", "correct_label")
    video_columns = {column: [] for column in columns}
    return {"video": video_columns}
def store_result(
    result, filename, y, y_val, klass, correct_label=None, compression=None
):
    """Append one video's prediction to ``result`` and return it.

    Args:
        result: Container shaped like the output of ``set_result()``; it is
            mutated in place.
        filename: Stored under ``"name"``.
        y: Predicted index; stored under ``"pred_label"`` after conversion by
            ``real_or_fake``.
        y_val: Stored under ``"pred"``.
        klass: String stored lower-cased under ``"klass"``.
        correct_label: Stored under ``"correct_label"`` when not ``None``.
        compression: Stored under ``"compression"`` when not ``None``.

    Returns:
        The same ``result`` object.
    """
    video = result["video"]
    video["name"].append(filename)
    video["pred"].append(y_val)
    video["klass"].append(klass.lower())
    video["pred_label"].append(real_or_fake(y))

    if correct_label is not None:
        video["correct_label"].append(correct_label)

    if compression is not None:
        # set_result() does not create a "compression" list, so make one on
        # first use instead of raising KeyError. A list the caller already
        # added is reused.
        video.setdefault("compression", []).append(compression)

    return result