|
import torch |
|
from torchvision import transforms |
|
from PIL import Image |
|
import torch.nn as nn |
|
import torch |
|
import torch.nn as nn |
|
from torchvision.models import resnet50 |
|
from utils.preprocessing import * |
|
from models.networks.drn_seg import DRNSub |
|
from utils.preprocessing import generate_local_image |
|
|
|
|
|
|
|
# Per-channel (R, G, B) statistics used to normalise model inputs; these are
# the commonly used ImageNet mean / standard-deviation values.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

# Input pipeline: convert the image to a float tensor, then normalise each
# channel with the statistics above.
tf = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
    ]
)
|
|
|
def load_fal_detector(model_path='models/weights/global.pth', gpu_id=-1):
    """Build the Photoshop FAL detector network and load its trained weights.

    Args:
        model_path (str): Path to the checkpoint file. It is loaded with
            ``weights_only=True`` and its ``'model'`` entry is used as the
            state dict.
        gpu_id (int): Index of the CUDA device to run on. Pass -1 to force the
            CPU. The CPU is also used whenever CUDA is unavailable.

    Returns:
        torch.nn.Module: A ``DRNSub(1)`` network moved to the chosen device,
        with that device string stored on ``.device``, in evaluation mode.
    """
    use_cuda = torch.cuda.is_available() and gpu_id != -1
    target_device = f'cuda:{gpu_id}' if use_cuda else 'cpu'

    detector = DRNSub(1)

    # Deserialize onto the CPU first; the network is moved to its final
    # device only after the weights are in place.
    checkpoint = torch.load(model_path, map_location='cpu', weights_only=True)
    detector.load_state_dict(checkpoint['model'])

    detector.to(target_device)
    # predict_fal_detector reads this to place its input tensor.
    detector.device = target_device
    detector.eval()
    return detector
|
|
|
|
|
def predict_fal_detector(model, image, no_crop=True):
    """
    Predict whether an image has been modified using Photoshop FAL.

    The image is passed through ``generate_local_image``. The result is
    resized so its shorter side is 400 px (``resize_shorter_side``),
    normalised with ``tf``, and scored by ``model`` as a batch of one.

    Args:
        model (torch.nn.Module): Loaded detector. It must carry a ``device``
            attribute (as set by ``load_fal_detector``) naming the device
            the input tensor is moved to.
        image: The image to score. Presumably a PIL image, since ``tf``
            starts with ``ToTensor``. TODO: confirm what
            ``generate_local_image`` accepts.
        no_crop (bool): Kept for backward compatibility. The image that is
            scored is always the output of ``generate_local_image``, so this
            flag currently has no effect.

    Returns:
        float: The model's sigmoid output multiplied by 100, i.e. a
        percentage in [0, 100]. Higher values mean the model considers the
        image more likely to be modified.
    """
    # NOTE(review): ``no_crop`` currently has no effect. The local image is
    # always scored. If the full image should be scored when ``no_crop`` is
    # True, branch on it here.
    face = generate_local_image(image)

    # Only the first element returned by resize_shorter_side is used
    # (presumably the resized image).
    face = resize_shorter_side(face, 400)[0]
    face_tens = tf(face).to(model.device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        # Add a batch dimension, take the first (only) output of the batch,
        # and squash it to (0, 1) with a sigmoid.
        fake_prob = model(face_tens.unsqueeze(0))[0].sigmoid().cpu().item()

    return fake_prob * 100
|
|
|
|