import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from torchvision.models import resnet50

# NOTE: the star import stays after the third-party imports and before the
# explicit project imports, matching the original name-shadowing order.
# It is also where ``resize_shorter_side`` (used below) comes from.
from utils.preprocessing import *
from models.networks.drn_seg import DRNSub
from utils.preprocessing import generate_local_image

# Preprocessing applied to every input before inference: convert to a tensor,
# then normalize each channel with the given mean/std (the commonly used
# ImageNet statistics).
tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def load_fal_detector(model_path='models/weights/global.pth', gpu_id=-1):
    """
    Load the Photoshop FAL detector (a ``DRNSub`` with one output) from disk.

    Args:
        model_path (str): Path to the checkpoint file. The file must hold a
            dict whose ``'model'`` entry is the network's state dict.
        gpu_id (int): CUDA device index to run on; -1 selects the CPU. The CPU
            is also used whenever CUDA is unavailable.

    Returns:
        torch.nn.Module: The model on the selected device, in evaluation mode,
        with the selected device string stored as ``model.device``.
    """
    if torch.cuda.is_available() and gpu_id != -1:
        device = f'cuda:{gpu_id}'
    else:
        device = 'cpu'

    model = DRNSub(1)

    # Deserialize onto the CPU first so GPU-saved checkpoints also load on
    # CPU-only machines; weights_only=True restricts what the unpickler accepts.
    state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
    model.load_state_dict(state_dict['model'])

    model.to(device)
    # predict_fal_detector reads this attribute to place its input tensor.
    model.device = device
    model.eval()
    return model


def predict_fal_detector(model, image, no_crop=True):
    """
    Predict whether an image has been modified using Photoshop FAL.

    Args:
        model (torch.nn.Module): Detector, typically from ``load_fal_detector``.
            Its ``device`` attribute is used when present; otherwise the device
            of its first parameter is used.
        image: Input image, passed to ``generate_local_image``.
            NOTE(review): presumably a PIL image -- confirm against
            ``utils.preprocessing``.
        no_crop (bool): Kept for API compatibility only. NOTE(review): it has
            no effect -- the face-detection path is disabled, so the network
            always receives ``generate_local_image(image)``.

    Returns:
        float: The model's sigmoid score scaled to a percentage in [0, 100]
        (higher means more likely modified).
    """
    # Face detection (the no_crop=False path) was disabled upstream; whichever
    # branch was taken, the original overwrote ``face`` with the local image.
    face = generate_local_image(image)

    # NOTE(review): presumably rescales so the shorter side is 400 px; only
    # the first element of its return value is used.
    face = resize_shorter_side(face, 400)[0]

    # Models not created by load_fal_detector have no ``device`` attribute;
    # fall back to wherever the weights live.
    device = getattr(model, 'device', None)
    if device is None:
        device = next(model.parameters()).device
    face_tens = tf(face).to(device)

    with torch.no_grad():
        # Add a batch dimension of 1, take that sample's output, squash it with
        # a sigmoid and pull it back as a Python float (.item() requires the
        # per-sample output to hold exactly one element).
        fake_prob = model(face_tens.unsqueeze(0))[0].sigmoid().cpu().item()
    return fake_prob * 100