from timm.models.vision_transformer import VisionTransformer
import torch
from torchvision import transforms
from PIL import Image
import cv2


def load_vit(weight_path="models/weights/cvit2_deepfake_detection_ep_50.pth"):
    """
    Load a custom Vision Transformer whose configuration matches the checkpoint structure.
    """
    model = VisionTransformer(
        img_size=224,
        patch_size=16,
        embed_dim=1024,
        depth=12,
        num_heads=16,
        num_classes=1,
    )
    checkpoint = torch.load(weight_path, map_location="cpu")
    # Some checkpoints wrap the weights in a "state_dict" key; others store them directly.
    state_dict = checkpoint.get("state_dict", checkpoint)
    # strict=False tolerates keys that differ between the checkpoint and the timm model.
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    return model


def process_video(video_path, model):
    """
    Process a video file and predict a deepfake probability for each frame.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    cap = cv2.VideoCapture(video_path)
    frame_probs = []
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        # OpenCV yields BGR numpy arrays; convert to RGB and wrap in a PIL image
        # so that transforms.Resize can process it.
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        input_tensor = transform(Image.fromarray(frame_rgb)).unsqueeze(0)

        # Predict: single logit -> sigmoid -> probability that the frame is a deepfake.
        with torch.no_grad():
            prob = torch.sigmoid(model(input_tensor)).item()
        frame_probs.append(prob)

    cap.release()
    return frame_probs
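

# Usage sketch (assumption: "sample.mp4" is a placeholder video path and the
# checkpoint above is present on disk; neither is part of the original code).
# It loads the model once, scores every frame, and averages the per-frame
# probabilities into a single clip-level deepfake score.
if __name__ == "__main__":
    vit = load_vit()
    probs = process_video("sample.mp4", vit)
    if probs:
        mean_prob = sum(probs) / len(probs)
        print(f"Frames scored: {len(probs)}, mean deepfake probability: {mean_prob:.3f}")
    else:
        print("No frames could be read from the video.")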