import warnings

import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from torchvision.models import efficientnet_b0

# Image preprocessing used by predict_efficientnet(). Built once at import time
# instead of being reconstructed on every prediction.
# NOTE(review): assumes the checkpoint was trained on 256x256 inputs normalised
# with the ImageNet mean/std below — confirm against the training pipeline.
_PREPROCESS = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


def load_efficientnet(weight_path="models/weights/efn-b0_LS_27_loss_0.2205.pth",
                      *, strict=False):
    """Build a torchvision EfficientNet-B0 and load custom weights into it.

    Args:
        weight_path: Path to a checkpoint that ``torch.load`` returns as a
            state dict for the model.
        strict: Forwarded to ``nn.Module.load_state_dict``. Defaults to False
            (the historical behaviour). Even when False, a warning is emitted
            if any keys are missing or unexpected, so a checkpoint that does
            not match the architecture no longer loads silently.

    Returns:
        The model on CPU, in evaluation mode.

    Raises:
        RuntimeError: From ``load_state_dict`` when a tensor shape mismatches,
            or when ``strict=True`` and keys do not match.
    """
    # Architecture only; all parameters come from the checkpoint below.
    model = efficientnet_b0(pretrained=False)
    # NOTE(review): the stock head has 1000 outputs. If the checkpoint was
    # trained with a single-logit head, call modify_efficientnet(model) here
    # first; otherwise the classifier weights will not fit.

    # SECURITY: torch.load unpickles the file and can execute arbitrary code.
    # Only load trusted checkpoints (on torch >= 1.13, consider passing
    # weights_only=True).
    checkpoint = torch.load(weight_path, map_location="cpu")
    result = model.load_state_dict(checkpoint, strict=strict)
    if result.missing_keys or result.unexpected_keys:
        warnings.warn(
            f"load_efficientnet({weight_path!r}): "
            f"{len(result.missing_keys)} missing key(s) "
            f"{result.missing_keys[:5]}, "
            f"{len(result.unexpected_keys)} unexpected key(s) "
            f"{result.unexpected_keys[:5]}; the model is only partially "
            "initialised from the checkpoint.",
            stacklevel=2,
        )
    model.eval()  # disable dropout / use running BatchNorm statistics
    return model


def modify_efficientnet(model):
    """Replace the model's last classifier layer with a single-output Linear.

    The layer is swapped in place on ``model.classifier[-1]``, keeping its
    input width, and the same model object is returned for convenience.
    The new layer is freshly initialised.
    """
    num_features = model.classifier[-1].in_features
    model.classifier[-1] = nn.Linear(num_features, 1)
    return model


def predict_efficientnet(model, image):
    """Return the sigmoid of the model's first output logit for one image.

    Args:
        model: A model returned by ``load_efficientnet``, already in eval mode
            and on any device.
        image: A PIL image of any mode (converted to RGB when needed), or any
            other input accepted by ``transforms.Resize`` / ``ToTensor``.

    Returns:
        float in [0, 1]: ``sigmoid(output[0, 0])``, interpreted as the
        DeepFake probability.
    """
    if isinstance(image, Image.Image) and image.mode != "RGB":
        # Normalize expects exactly 3 channels; grayscale, RGBA or palette
        # images would otherwise fail with a shape mismatch.
        image = image.convert("RGB")

    # Put the batch on the same device as the model's parameters.
    device = next(model.parameters()).device
    input_tensor = _PREPROCESS(image).unsqueeze(0).to(device)  # add batch dim

    with torch.no_grad():
        output = model(input_tensor)

    # NOTE(review): column 0 is only a DeepFake logit if the loaded model has a
    # single-output head; with the stock 1000-class head it is an ImageNet
    # class logit. Confirm against the checkpoint.
    return torch.sigmoid(output[:, 0]).item()