# deepsaif/models/efficientnet.py
# Uploaded by 22GC22 in commit d7f12b9 ("Upload 12 files", verified).
import warnings

import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from torchvision.models import efficientnet_b0
def load_efficientnet(weight_path="models/weights/efn-b0_LS_27_loss_0.2205.pth"):
    """
    Load torchvision's EfficientNet-B0 and apply the custom weights in ``weight_path``.

    The architecture is built without ImageNet weights. The checkpoint is then
    loaded with ``strict=False``: keys present on only one side are skipped,
    and a warning lists them so that a mismatched checkpoint does not go
    unnoticed. Shape mismatches on shared keys still raise ``RuntimeError``
    from ``load_state_dict``.

    Args:
        weight_path: Path to a checkpoint file. It may hold a bare state dict,
            or a dict wrapping one under ``"state_dict"`` or
            ``"model_state_dict"``. If every key carries a ``"module."``
            prefix (as written by ``nn.DataParallel``), the prefix is stripped.

    Returns:
        The model on the CPU, in evaluation mode.
    """
    # ``pretrained=False`` is the pre-0.13 torchvision spelling of ``weights=None``.
    model = efficientnet_b0(pretrained=False)
    # NOTE(review): if the checkpoint was trained with a single-output head
    # (see modify_efficientnet), that head must be installed before loading,
    # otherwise the classifier weights fail with a size mismatch -- confirm
    # how these weights were produced.

    # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(weight_path, map_location="cpu")

    # Unwrap training-loop checkpoints that nest the weights under a key.
    state_dict = checkpoint
    if isinstance(checkpoint, dict):
        for wrapper_key in ("state_dict", "model_state_dict"):
            if isinstance(checkpoint.get(wrapper_key), dict):
                state_dict = checkpoint[wrapper_key]
                break

    # With strict=False a universal "module." prefix would match nothing at
    # all and leave the model randomly initialised, so strip it.
    prefix = "module."
    if isinstance(state_dict, dict) and state_dict and all(
        isinstance(key, str) and key.startswith(prefix) for key in state_dict
    ):
        state_dict = {key[len(prefix):]: value for key, value in state_dict.items()}

    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        warnings.warn(
            f"load_efficientnet: checkpoint {weight_path!r} left "
            f"{len(incompatible.missing_keys)} model keys unset and had "
            f"{len(incompatible.unexpected_keys)} unused keys "
            f"(first missing: {incompatible.missing_keys[:3]}, "
            f"first unexpected: {incompatible.unexpected_keys[:3]})",
            stacklevel=2,
        )
    model.eval()  # disable dropout / use running batch-norm statistics
    return model
def modify_efficientnet(model, num_outputs=1):
    """
    Replace the final classifier layer of ``model`` with a fresh linear layer.

    With the default ``num_outputs=1`` this turns the network into a
    single-logit binary classifier. The new layer keeps the input width,
    device and dtype of the layer it replaces. Its weights are freshly
    initialised.

    Args:
        model: A module whose ``classifier`` is indexable and ends in an
            ``nn.Linear``, as in torchvision's EfficientNet.
        num_outputs: Output width of the new final layer.

    Returns:
        The same ``model`` object, modified in place.
    """
    old_head = model.classifier[-1]
    # Match device/dtype so a model already moved to GPU or half precision
    # does not end up with a CPU float32 head.
    model.classifier[-1] = nn.Linear(
        old_head.in_features,
        num_outputs,
        device=old_head.weight.device,
        dtype=old_head.weight.dtype,
    )
    return model
def predict_efficientnet(model, image):
    """
    Predict the DeepFake probability of a single image with EfficientNet.

    Args:
        model: An EfficientNet module, e.g. as returned by ``load_efficientnet``.
            The input batch is moved to the device of the model's parameters.
        image: A PIL image. Non-RGB modes (RGBA, grayscale, palette, ...) are
            converted to RGB so the 3-channel normalisation below applies.

    Returns:
        ``sigmoid`` of the model's first output logit, as a Python float in [0, 1].
    """
    preprocess = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        # ImageNet per-channel mean / std.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # ToTensor keeps the image's channel count, but Normalize expects exactly 3.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")
    input_tensor = preprocess(image).unsqueeze(0)  # add batch dimension -> (1, 3, 256, 256)

    # Run on whichever device the model's parameters live on.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        input_tensor = input_tensor.to(first_param.device)

    with torch.no_grad():
        output = model(input_tensor)

    # NOTE(review): column 0 is treated as the DeepFake logit. For an unmodified
    # 1000-class head this is ImageNet class 0, not a DeepFake score -- confirm
    # the checkpoint uses a single-output head (see modify_efficientnet).
    return torch.sigmoid(output[:, 0]).item()