File size: 2,564 Bytes
d7f12b9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import torch
from torchvision import transforms
from PIL import Image
import torch.nn as nn
import torch
import torch.nn as nn
from torchvision.models import resnet50
from utils.preprocessing import *
from models.networks.drn_seg import DRNSub
from utils.preprocessing import generate_local_image
# Input preprocessing shared by all predictions: convert the image to a
# float tensor, then normalize each channel using the standard ImageNet
# per-channel mean and standard deviation.
tf = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def load_fal_detector(model_path='models/weights/global.pth', gpu_id=-1):
    """
    Load the Photoshop FAL detector model with the specified weights.

    Args:
        model_path (str): Path to the checkpoint file. The checkpoint must be
            a mapping whose ``'model'`` entry holds the state dict for
            ``DRNSub(1)``.
        gpu_id (int): CUDA device index to use. Any negative value (such as
            the default -1) selects the CPU. If CUDA is not available, the
            model is placed on the CPU regardless of ``gpu_id``.

    Returns:
        torch.nn.Module: The loaded model in evaluation mode. The selected
        device string is stored on ``model.device`` so callers can move
        their inputs to the same device.
    """
    # Treat every negative id as "CPU" so values like -2 do not produce the
    # invalid device string 'cuda:-2'.
    if gpu_id >= 0 and torch.cuda.is_available():
        device = f'cuda:{gpu_id}'
    else:
        device = 'cpu'

    # Build the architecture, then fill it with the checkpoint weights.
    model = DRNSub(1)
    # Map tensors to the CPU first so a checkpoint saved on a GPU can be read
    # on any machine; weights_only=True restricts what unpickling may create.
    state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
    model.load_state_dict(state_dict['model'])

    # Move to the target device and record it for predict-time input placement.
    model.to(device)
    model.device = device
    model.eval()
    return model
def predict_fal_detector(model, image, no_crop=True):
    """
    Score how likely an image is to have been modified using Photoshop FAL.

    Args:
        model (torch.nn.Module): Model returned by ``load_fal_detector``; its
            ``device`` attribute decides where the input tensor is placed.
        image: Input image. Presumably a PIL image, since it is passed to
            ``resize_shorter_side`` and ``tf`` -- TODO confirm accepted types.
        no_crop (bool): If True, process the full image. If False, process
            the output of ``generate_local_image(image)`` instead; this
            replaces the face-detection crop that was removed from this
            function.

    Returns:
        float: The sigmoid of the model's output for the image, scaled by
        100 (i.e. a percentage in [0, 100]).
    """
    if no_crop:
        face = image
    else:
        # NOTE(review): dlib face detection was removed here;
        # generate_local_image is used in its place -- confirm intended.
        face = generate_local_image(image)

    # The normalization in `tf` uses three channel statistics, so PIL images
    # in other modes (RGBA, L, P, ...) would fail there; convert them first.
    if isinstance(face, Image.Image) and face.mode != 'RGB':
        face = face.convert('RGB')

    # Resize (shorter side to 400, per the helper's name), then convert to a
    # normalized tensor on the model's device.
    face = resize_shorter_side(face, 400)[0]
    face_tens = tf(face).to(model.device)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        score = model(face_tens.unsqueeze(0))[0]
        fake_prob = score.sigmoid().cpu().item()
    return fake_prob * 100
|