# oo01's picture
# Upload 6 files
# 3b237c2 verified
import torch
import torch.nn as nn
import timm
from pathlib import Path
import logging
import os
# Module-scoped logger; records propagate to the root logger configured below.
logger = logging.getLogger(__name__)
# Give the root logger an INFO-level console handler (no-op if the root
# logger already has handlers).
logging.basicConfig(level=logging.INFO)
class EfficientNetDeepFakeDetector(nn.Module):
    """Frame-level EfficientNet-B0 with temporal mean-pooling.

    Every frame of a clip is encoded independently by an EfficientNet-B0
    backbone (timm classifier removed, global average pooling kept). The
    per-frame features are averaged over time, and a small MLP head maps the
    pooled feature to one logit per clip.
    """

    # Width of the pooled EfficientNet-B0 feature vector fed to the head.
    FEAT_DIM = 1280
    # BatchNorm variants in the backbone that are kept frozen.
    _BN_TYPES = (nn.BatchNorm2d, nn.SyncBatchNorm)

    def __init__(self, dropout: float = 0.4):
        """Build the backbone and the classifier head.

        Args:
            dropout: Dropout probability before the first head layer; the
                dropout before the last layer uses half of this value.
        """
        super().__init__()
        # Backbone. pretrained=False: no weights are downloaded here.
        # num_classes=0 strips timm's classifier so the model returns the
        # globally average-pooled feature vector.
        backbone = timm.create_model(
            'efficientnet_b0',
            pretrained=False,
            num_classes=0,
            global_pool='avg',
        )
        # Freeze BatchNorm layers: no gradients for their affine parameters
        # and no running-statistic updates (the latter is re-applied in
        # train(), because nn.Module.train() would otherwise undo it).
        for m in backbone.modules():
            if isinstance(m, self._BN_TYPES):
                m.eval()
                for p in m.parameters():
                    p.requires_grad = False
        self.backbone = backbone
        # Classifier head: FEAT_DIM -> 256 -> 1 logit.
        self.head = nn.Sequential(
            nn.LayerNorm(self.FEAT_DIM),
            nn.Dropout(dropout),
            nn.Linear(self.FEAT_DIM, 256),
            nn.GELU(),
            nn.Dropout(dropout * 0.5),
            nn.Linear(256, 1),
        )

    def train(self, mode: bool = True) -> "EfficientNetDeepFakeDetector":
        """Set train/eval mode while keeping backbone BatchNorm frozen.

        ``nn.Module.train`` recursively flips every submodule, which would
        silently re-enable running-statistic updates on the BatchNorm layers
        that ``__init__`` put in eval mode.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            ``self``, as ``nn.Module.train`` does.
        """
        super().train(mode)
        for m in self.backbone.modules():
            if isinstance(m, self._BN_TYPES):
                m.eval()
        return self

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Score a batch of clips.

        Args:
            x: Tensor of shape (B, T, C, H, W).

        Returns:
            Tensor of shape (B,) with one raw (pre-sigmoid) logit per clip.

        Raises:
            ValueError: If ``x`` is not 5-dimensional.
        """
        if x.dim() != 5:
            raise ValueError(
                f"expected input of shape (B, T, C, H, W), got {tuple(x.shape)}"
            )
        B, T, C, H, W = x.shape
        # reshape (not view): callers may hand in non-contiguous tensors,
        # e.g. after a transpose/permute of the batch or time axes.
        feat = self.backbone(x.reshape(B * T, C, H, W))  # (B*T, FEAT_DIM)
        # Temporal mean-pooling over the T frames of each clip.
        feat = feat.reshape(B, T, self.FEAT_DIM).mean(dim=1)
        logit = self.head(feat).squeeze(-1)
        return logit
class DeepFakeModel:
    """Inference wrapper around EfficientNetDeepFakeDetector.

    Loads a checkpoint, keeps the network in eval mode and turns its single
    logit into a REAL/FAKE decision.
    """

    def __init__(self, model_path: str, device: str = "cpu"):
        """Build the detector and load its weights.

        Args:
            model_path: Path to a checkpoint file written by ``torch.save``.
            device: Torch device string, e.g. ``"cpu"`` or ``"cuda"``.

        Raises:
            FileNotFoundError: If ``model_path`` is not an existing file.
        """
        self.device = torch.device(device)
        self.model = EfficientNetDeepFakeDetector(dropout=0.4).to(self.device)
        self._load_model(model_path)
        self.model.eval()
        logger.info("Model loaded on %s", self.device)

    def _load_model(self, model_path: str):
        """Load weights from a checkpoint file into ``self.model``.

        Two formats are accepted: a dict with a ``'model_state_dict'`` entry
        (optionally with ``'epoch'`` / ``'val_f1_macro'`` metadata), or a bare
        state dict.

        Raises:
            FileNotFoundError: If ``model_path`` does not name a regular file.
        """
        # isfile (not exists): a directory would otherwise pass this check and
        # fail later inside torch.load with a less helpful error.
        if not os.path.isfile(model_path):
            raise FileNotFoundError(f"Model file not found at {model_path}")
        # SECURITY: weights_only=False unpickles arbitrary Python objects and
        # can execute code embedded in the file. Only load trusted checkpoints;
        # prefer weights_only=True if the checkpoint holds only tensors and
        # plain containers/primitives.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
        # Handle different checkpoint formats
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            self.model.load_state_dict(checkpoint['model_state_dict'])
            epoch = checkpoint.get('epoch', 'unknown')
            val_f1 = checkpoint.get('val_f1_macro', 'unknown')
            logger.info("Loaded checkpoint from epoch %s (val_f1=%s)", epoch, val_f1)
        else:
            # If checkpoint is just the state dict
            self.model.load_state_dict(checkpoint)
            logger.info("Loaded model state dict")

    @torch.no_grad()
    def predict(self, video_tensor: torch.Tensor, threshold: float = 0.5) -> dict:
        """
        Predict if video is real or fake.

        Args:
            video_tensor: Tensor of shape (T, 3, H, W) or (1, T, 3, H, W)
            threshold: Decision threshold on P(REAL) (default: 0.5)
        Returns:
            dict with prediction, confidence, and probabilities
        Raises:
            ValueError: If video_tensor is not 4- or 5-dimensional.
        """
        if video_tensor.dim() == 4:
            # Single clip without a batch axis -> add B=1.
            video_tensor = video_tensor.unsqueeze(0)
        elif video_tensor.dim() != 5:
            raise ValueError(
                "video_tensor must have shape (T, 3, H, W) or (1, T, 3, H, W), "
                f"got {tuple(video_tensor.shape)}"
            )
        video_tensor = video_tensor.to(self.device)
        logit = self.model(video_tensor)
        prob = torch.sigmoid(logit).item()
        # prob = P(REAL), because training used label 1=REAL, 0=FAKE
        prediction = "REAL" if prob >= threshold else "FAKE"
        confidence = prob if prediction == "REAL" else 1 - prob
        return {
            "prediction": prediction,
            "confidence": round(confidence, 4),
            "probability_real": round(prob, 4),
            "probability_fake": round(1 - prob, 4),
            "threshold": threshold,
        }

    @torch.no_grad()
    def predict_from_video_path(
        self,
        video_path: str,
        threshold: float = 0.5,
        *,
        num_frames: int = 16,
        img_size: int = 224,
    ) -> dict:
        """
        Convenience method to predict directly from video file path.

        Args:
            video_path: Path to video file
            threshold: Decision threshold on P(REAL)
            num_frames: Number of frames passed to ``video_to_tensor``
                (default 16)
            img_size: Frame size passed to ``video_to_tensor`` (default 224)
        Returns:
            Prediction result dictionary (see ``predict``)
        """
        # Local import: the helper lives in the package's utils module.
        from .utils import video_to_tensor
        video_tensor = video_to_tensor(
            video_path,
            num_frames=num_frames,
            img_size=img_size,
        )
        return self.predict(video_tensor, threshold)