# trying-deepfake / model / pred_func.py
# tony133777's picture
# new
# aae9c6b
import os
import numpy as np
import cv2
import torch
from torchvision import transforms
from tqdm import tqdm
from dataset.loader import normalize_data
from .config import load_config
from .genconvit import GenConViT
from decord import VideoReader, cpu
from .face_detection import detect_faces # Import our new function
# Device string used by this module: "cuda" when torch reports CUDA is
# available, otherwise "cpu".
device = "cuda" if torch.cuda.is_available() else "cpu"
def load_genconvit(config, net, ed_weight, vae_weight, fp16):
    """Build a ``GenConViT`` instance and prepare it for use.

    Args:
        config: Passed through as the first positional argument of
            ``GenConViT``.
        net: Forwarded as the ``net`` keyword.
        ed_weight: Forwarded as the ``ed`` keyword.
        vae_weight: Forwarded as the ``vae`` keyword.
        fp16: Forwarded as the ``fp16`` keyword; when truthy the model is
            additionally converted with ``half()``.

    Returns:
        The constructed model after ``to(device)`` and ``eval()`` have been
        called on it (and ``half()`` when ``fp16`` is truthy).
    """
    detector = GenConViT(
        config,
        ed=ed_weight,
        vae=vae_weight,
        net=net,
        fp16=fp16,
    )

    detector.to(device)
    detector.eval()

    if not fp16:
        return detector

    detector.half()
    return detector
# Replace face_rec with our new function
def face_rec(frames, p=None, klass=None):
    """Detect faces in ``frames`` by delegating to ``detect_faces``.

    ``p`` and ``klass`` are accepted but not used by this function.

    Returns:
        Whatever ``detect_faces(frames)`` returns, unchanged.
    """
    detection = detect_faces(frames)
    return detection
def preprocess_frame(frame):
    """Convert a batch of frames to a normalized float tensor on ``device``.

    Args:
        frame: Array-like batch of frames. The ``permute((0, 3, 1, 2))``
            below requires four dimensions; presumably the layout is
            (batch, height, width, channels) with 0-255 pixel values --
            TODO confirm against ``detect_faces``.

    Returns:
        A float tensor with the last input axis moved to position 1, where
        each batch item has been divided by 255.0 and passed through the
        ``"vid"`` transform from ``normalize_data()``.
    """
    df_tensor = torch.tensor(frame, device=device).float()
    df_tensor = df_tensor.permute((0, 3, 1, 2))

    # The transform lookup does not depend on the loop variable, so fetch it
    # once instead of calling normalize_data() for every frame.
    # NOTE(review): assumes normalize_data() returns equivalent, stateless
    # transforms on every call -- confirm in dataset/loader.py.
    normalize = normalize_data()["vid"]

    for index in range(df_tensor.shape[0]):
        df_tensor[index] = normalize(df_tensor[index] / 255.0)

    return df_tensor
def pred_vid(df, model):
    """Run ``model`` on ``df`` and summarize its output.

    The model is called with gradient tracking disabled; its output is
    squeezed, passed through a sigmoid, and handed to
    ``max_prediction_value``.

    Returns:
        The tuple returned by ``max_prediction_value``.
    """
    with torch.no_grad():
        raw_output = model(df)
        scores = torch.sigmoid(raw_output.squeeze())
        return max_prediction_value(scores)
def max_prediction_value(y_pred):
    """Summarize prediction scores as a (class index, value) pair.

    The scores are averaged over dimension 0, and the first two entries of
    that mean are then compared.

    Args:
        y_pred: Tensor of scores with at least two entries along its last
            dimension. A 1-D tensor is treated as a single row; this is the
            shape a caller produces when it squeezes the output for a
            one-item batch.

    Returns:
        A tuple ``(index, value)`` where ``index`` is the argmax of the
        averaged scores and ``value`` is ``mean[0]`` when
        ``mean[0] > mean[1]``, otherwise ``abs(1 - mean[1])``. Both are
        plain Python numbers.
    """
    # A 1-D input would otherwise be averaged down to a 0-dim tensor, and
    # indexing that with [0] raises IndexError.
    if y_pred.dim() == 1:
        y_pred = y_pred.unsqueeze(0)

    mean_val = torch.mean(y_pred, dim=0)
    first, second = mean_val[0], mean_val[1]

    value = first if first > second else abs(1 - second)
    return torch.argmax(mean_val).item(), value.item()
def real_or_fake(prediction):
    """Map a class index to its label string.

    The lowest bit of ``prediction`` is flipped before the lookup, so ``0``
    yields ``"FAKE"`` and ``1`` yields ``"REAL"``. Any other value raises
    ``KeyError``.
    """
    labels = {0: "REAL", 1: "FAKE"}
    flipped = prediction ^ 1
    return labels[flipped]
def extract_frames(video_file, frames_nums=15):
    """Read a strided selection of frames from a video.

    Frame indices start at 0 and advance by
    ``max(1, len(reader) // frames_nums)``; the resulting index list is
    sliced with ``[:frames_nums]``.

    Args:
        video_file: Passed to ``VideoReader`` (opened with ``ctx=cpu(0)``).
        frames_nums: Divisor for the stride and the slice bound for the
            index list.

    Returns:
        The selected batch converted with ``asnumpy()``.
    """
    reader = VideoReader(video_file, ctx=cpu(0))
    total_frames = len(reader)

    # Never step by less than one frame, even for very short videos.
    stride = max(1, total_frames // frames_nums)
    wanted = list(range(0, total_frames, stride))[:frames_nums]

    batch = reader.get_batch(wanted)
    return batch.asnumpy()
def df_face(vid, num_frames, net):
    """Extract frames from ``vid``, detect faces, and preprocess them.

    ``net`` is accepted but not used by this function.

    Returns:
        The result of ``preprocess_frame`` on the detected faces when the
        reported count is greater than zero; otherwise an empty list.
    """
    frames = extract_frames(vid, num_frames)
    faces, found = face_rec(frames)

    if found > 0:
        return preprocess_frame(faces)
    return []
def is_video(vid):
    """Report whether ``vid`` is an existing file with a video extension.

    The extension check is case-insensitive, so ``clip.MP4`` is accepted
    just like ``clip.mp4``. Recognized extensions: .avi, .mp4, .mpg, .mpeg
    and .mov.

    Args:
        vid: Path string to check.

    Returns:
        ``True`` when the path is a regular file and has one of the
        recognized extensions, ``False`` otherwise.
    """
    # Query the filesystem once so the printed value and the returned
    # decision are based on the same observation.
    exists = os.path.isfile(vid)
    # NOTE(review): debug output kept so stdout stays unchanged for callers.
    print('IS FILE', exists)
    return exists and vid.lower().endswith(
        (".avi", ".mp4", ".mpg", ".mpeg", ".mov")
    )
def set_result():
    """Create an empty result container.

    Returns:
        A dict with a single ``"video"`` entry that maps each of the column
        names ``name``, ``pred``, ``klass``, ``pred_label`` and
        ``correct_label`` to its own new empty list.
    """
    columns = ("name", "pred", "klass", "pred_label", "correct_label")
    return {"video": {column: [] for column in columns}}
def store_result(
    result, filename, y, y_val, klass, correct_label=None, compression=None
):
    """Append one video's prediction to ``result`` and return it.

    Args:
        result: Dict with a ``"video"`` entry holding the lists ``name``,
            ``pred``, ``klass``, ``pred_label`` and ``correct_label``.
            It is modified in place.
        filename: Appended to ``name``.
        y: Class index; its label from ``real_or_fake`` is appended to
            ``pred_label``.
        y_val: Appended to ``pred``.
        klass: String appended to ``klass`` after lower-casing.
        correct_label: Appended to ``correct_label`` when not ``None``.
        compression: Appended to ``compression`` when not ``None``.

    Returns:
        The same ``result`` object that was passed in.
    """
    video = result["video"]

    video["name"].append(filename)
    video["pred"].append(y_val)
    video["klass"].append(klass.lower())
    video["pred_label"].append(real_or_fake(y))

    if correct_label is not None:
        video["correct_label"].append(correct_label)

    if compression is not None:
        # The result dict may have no "compression" list yet, so create it
        # on first use instead of raising KeyError.
        video.setdefault("compression", []).append(compression)

    return result