|
import warnings

import cv2
import torch
from timm.models.vision_transformer import VisionTransformer
from torchvision import transforms
|
|
|
def load_vit(weight_path="models/weights/cvit2_deepfake_detection_ep_50.pth"):
    """Build a timm ``VisionTransformer`` and load weights from *weight_path*.

    The architecture is fixed here: 224x224 input, 16x16 patches,
    1024-dim embeddings, 12 blocks, 16 attention heads and a single output
    (``num_classes=1``). It is expected to match the checkpoint.

    Args:
        weight_path: Path to a file written by ``torch.save``. It may hold a
            raw state dict or a dict that stores one under ``"state_dict"``.

    Returns:
        The model with weights loaded, on CPU, in eval mode.

    Warns:
        UserWarning: if the checkpoint keys do not line up with the model's
            parameters. Loading is non-strict, so unmatched parameters keep
            their random initialization.
    """
    model = VisionTransformer(
        img_size=224,
        patch_size=16,
        embed_dim=1024,
        depth=12,
        num_heads=16,
        num_classes=1,
    )

    # NOTE(review): torch.load unpickles arbitrary objects. Only load
    # checkpoints from a trusted source.
    checkpoint = torch.load(weight_path, map_location="cpu")
    state_dict = checkpoint.get("state_dict", checkpoint)

    # Checkpoints saved from a DataParallel/DDP wrapper prefix every key with
    # "module.". Strip it so the keys match the bare model.
    prefix = "module."
    if state_dict and all(key.startswith(prefix) for key in state_dict):
        state_dict = {key[len(prefix):]: value for key, value in state_dict.items()}

    # strict=False tolerates extra or missing entries. On its own, though, it
    # silently leaves unmatched parameters untrained. Report any mismatch so a
    # wrong checkpoint is not mistaken for a working model.
    result = model.load_state_dict(state_dict, strict=False)
    if result.missing_keys or result.unexpected_keys:
        warnings.warn(
            f"Checkpoint {weight_path!r} does not fully match the model: "
            f"{len(result.missing_keys)} missing key(s) "
            f"(e.g. {result.missing_keys[:3]}), "
            f"{len(result.unexpected_keys)} unexpected key(s) "
            f"(e.g. {result.unexpected_keys[:3]}).",
            stacklevel=2,
        )

    model.eval()
    return model
|
|
|
def _model_device(model):
    """Return the device of *model*'s first parameter, or CPU if there is none."""
    parameters = getattr(model, "parameters", None)
    if callable(parameters):
        first = next(iter(parameters()), None)
        if first is not None:
            return first.device
    return torch.device("cpu")


def process_video(video_path, model):
    """Run *model* on every frame of a video and collect per-frame probabilities.

    Each decoded frame goes through these steps:

    1. Convert from OpenCV's BGR channel order to RGB.
    2. Resize to 224x224.
    3. Normalize with the ImageNet mean/std.
    4. Pass through *model* as a batch of one.

    The model output is mapped through a sigmoid.

    Args:
        video_path: Anything accepted by ``cv2.VideoCapture`` (typically a
            file path).
        model: A callable (normally a ``torch.nn.Module`` such as the one
            returned by ``load_vit``) that yields a single-element output per
            image. It should already be in eval mode.

    Returns:
        list[float]: one sigmoid probability per decoded frame, in frame
        order. The list is empty if the video cannot be opened or yields no
        frames.
    """
    transform = transforms.Compose([
        # cv2 returns numpy arrays. torchvision's (v1) Resize only accepts PIL
        # images or tensors, so convert to PIL first.
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Send inputs to wherever the model's weights live, so a GPU-resident
    # model does not hit a device mismatch.
    device = _model_device(model)

    frame_probs = []
    cap = cv2.VideoCapture(video_path)
    try:
        with torch.no_grad():
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break

                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                input_tensor = transform(frame_rgb).unsqueeze(0).to(device)

                # .item() requires the model output to contain exactly one value.
                prob = torch.sigmoid(model(input_tensor)).item()
                frame_probs.append(prob)
    finally:
        # Release the capture even if decoding or inference raises.
        cap.release()

    return frame_probs
|
|