|
import torch |
|
from torchvision import transforms |
|
from PIL import Image |
|
from torchvision.models import efficientnet_b0 |
|
import torch.nn as nn |
|
|
|
def load_efficientnet(weight_path="models/weights/efn-b0_LS_27_loss_0.2205.pth", *, strict=False):
    """Build an EfficientNet-B0 and load custom weights into it.

    The architecture is created without pretrained weights. The checkpoint at
    ``weight_path`` is then loaded onto the CPU and applied with
    ``load_state_dict``.

    With the default ``strict=False``, keys that do not line up between the
    checkpoint and the model are tolerated. They are reported through
    ``warnings.warn`` instead of being silently ignored, because an unmatched
    layer keeps its initial (random) values and the model would otherwise
    quietly produce meaningless predictions.

    Args:
        weight_path: Path to a file readable by ``torch.load``. It is assumed
            to hold a bare state dict for the model (TODO: confirm against the
            training script; a wrapped dict such as ``{"state_dict": ...}``
            would load nothing and trigger the warning).
        strict: Forwarded to ``load_state_dict``. When True, any key mismatch
            raises ``RuntimeError`` instead of warning.

    Returns:
        The model, switched to evaluation mode.
    """
    import warnings

    # NOTE(review): torchvision >= 0.13 deprecates ``pretrained`` in favour of
    # ``weights=None``; switch once the minimum torchvision version is confirmed.
    model = efficientnet_b0(pretrained=False)

    checkpoint = torch.load(weight_path, map_location="cpu")
    incompatible = model.load_state_dict(checkpoint, strict=strict)

    missing = list(incompatible.missing_keys)
    unexpected = list(incompatible.unexpected_keys)
    if missing or unexpected:
        # Only reachable with strict=False; strict=True raises inside load_state_dict.
        warnings.warn(
            f"load_efficientnet: checkpoint {weight_path!r} did not match the model "
            f"exactly ({len(missing)} missing, {len(unexpected)} unexpected keys; "
            f"first missing={missing[:5]}, first unexpected={unexpected[:5]}). "
            "Unmatched layers keep their initial values.",
            stacklevel=2,
        )

    model.eval()
    return model
|
|
|
|
|
|
|
def modify_efficientnet(model, num_outputs=1):
    """Replace the last layer of ``model.classifier`` with a fresh linear head.

    The new ``nn.Linear`` keeps the input width (``in_features``) of the layer
    it replaces and produces ``num_outputs`` values. The default of 1 gives a
    single logit.

    The head is created on the same device and with the same dtype as the
    replaced layer's weight. A model that has already been moved to a GPU or
    to another precision therefore stays consistent; otherwise the new layer
    would sit on the CPU in float32 and the forward pass would fail.

    The model is modified in place and also returned for convenience.

    Args:
        model: A module whose ``classifier`` is indexable and whose last
            element exposes ``in_features`` (e.g. an ``nn.Linear``).
        num_outputs: Number of output features of the new head.

    Returns:
        The same ``model`` object.
    """
    old_head = model.classifier[-1]
    new_head = nn.Linear(old_head.in_features, num_outputs)

    # Match placement/precision of the layer being replaced, if it has a weight.
    old_weight = getattr(old_head, "weight", None)
    if old_weight is not None:
        new_head = new_head.to(device=old_weight.device, dtype=old_weight.dtype)

    model.classifier[-1] = new_head
    return model
|
|
|
|
|
def predict_efficientnet(model, image):
    """Return the sigmoid of the model's first output for a single image.

    This is used as the DeepFake probability. It is only meaningful when the
    model has a single-logit head; with a multi-class head it is just
    sigmoid(logit of class 0).

    Preprocessing steps:
      1. PIL images are first converted to RGB, so grayscale, palette or RGBA
         inputs match the 3-channel normalisation below.
      2. Resize to 256x256.
      3. ``ToTensor``.
      4. Normalise with the ImageNet mean/std.

    The single-image batch is moved to the device of the model's first
    parameter (CPU if it has none). The forward pass runs under
    ``torch.no_grad()``.

    The model's train/eval mode is not changed here. Callers should pass a
    model already in eval mode (``load_efficientnet`` returns one).

    Args:
        model: Callable mapping a ``(1, 3, 256, 256)`` float tensor to an
            output indexable as ``output[:, 0]``.
        image: A PIL image, or anything ``transforms.ToTensor`` accepts.

    Returns:
        A Python float in [0, 1].
    """
    import logging

    preprocess = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Non-RGB PIL modes yield 1 or 4 channels, which the 3-channel Normalize rejects.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    # Put the input where the weights live, so GPU-resident models also work.
    first_param = next(model.parameters(), None) if hasattr(model, "parameters") else None
    device = first_param.device if first_param is not None else torch.device("cpu")
    input_tensor = preprocess(image).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(input_tensor)

    # Debug trace of the raw output (previously an unconditional print).
    logging.getLogger(__name__).debug("predict_efficientnet raw output: %s", output)

    binary_prob = torch.sigmoid(output[:, 0]).item()
    return binary_prob
|
|