# deepsaif/models/vit.py
from timm.models.vision_transformer import VisionTransformer
import torch
from torchvision import transforms
from PIL import Image
import cv2


def load_vit(weight_path="models/weights/cvit2_deepfake_detection_ep_50.pth"):
    """
    Loads a Vision Transformer model configured to match the checkpoint structure
    and returns it in evaluation mode.
    """
    model = VisionTransformer(
        img_size=224,
        patch_size=16,
        embed_dim=1024,
        depth=12,
        num_heads=16,
        num_classes=1,
    )
    checkpoint = torch.load(weight_path, map_location="cpu")
    # Checkpoints may store the weights under a "state_dict" key or at the top level.
    state_dict = checkpoint.get("state_dict", checkpoint)
    # strict=False tolerates keys that do not line up with the timm VisionTransformer layout.
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    return model
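

# Quick sanity check for load_vit (illustrative sketch, not part of the original
# module and not executed on import): with the checkpoint in place, a random
# 1x3x224x224 batch yields a single raw logit, which torch.sigmoid maps to a
# deepfake probability.
#
#     model = load_vit()
#     logit = model(torch.randn(1, 3, 224, 224))   # shape: (1, 1)
#     prob = torch.sigmoid(logit).item()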


def process_video(video_path, model):
    """
    Processes a video file and predicts a deepfake probability for each frame.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    cap = cv2.VideoCapture(video_path)
    frame_probs = []
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        # Convert the BGR frame to an RGB PIL image so the torchvision transforms apply.
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        input_tensor = transform(Image.fromarray(frame_rgb)).unsqueeze(0)
        # Predict the per-frame deepfake probability.
        with torch.no_grad():
            prob = torch.sigmoid(model(input_tensor)).item()
        frame_probs.append(prob)
    cap.release()
    return frame_probs
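

if __name__ == "__main__":
    # Illustrative usage sketch (not part of the original module): the video path
    # and the 0.5 decision threshold below are assumptions for demonstration only.
    model = load_vit()
    probs = process_video("sample.mp4", model)
    if probs:
        mean_prob = sum(probs) / len(probs)
        print(f"Frames analysed: {len(probs)}")
        print(f"Mean deepfake probability: {mean_prob:.3f}")
        print("Prediction:", "FAKE" if mean_prob > 0.5 else "REAL")
    else:
        print("No frames could be read from the video.")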