# deepsaif/models/fal_detector.py
# Hugging Face listing metadata: uploader 22GC22, commit "Upload 12 files" (d7f12b9, verified).
import torch
from torchvision import transforms
from PIL import Image
import torch.nn as nn
import torch
import torch.nn as nn
from torchvision.models import resnet50
from utils.preprocessing import *
from models.networks.drn_seg import DRNSub
from utils.preprocessing import generate_local_image
# Preprocessing pipeline applied to an image before inference: convert it to a
# tensor, then normalize each of the three channels with fixed mean/std values.
# These values match the channel statistics commonly used for ImageNet-trained
# torchvision models.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]
tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
])
def load_fal_detector(model_path='models/weights/global.pth', gpu_id=-1):
    """
    Build the Photoshop FAL detector network and restore its weights.

    Args:
        model_path (str): Location of the checkpoint file. The checkpoint is
            expected to hold the network weights under the ``'model'`` key.
        gpu_id (int): Index of the CUDA device to run on. Pass -1 to stay on
            the CPU; the CPU is also used whenever CUDA is unavailable.

    Returns:
        torch.nn.Module: The network in evaluation mode, already moved to the
        chosen device. The device string is also stored on the returned
        object as ``model.device``.
    """
    use_gpu = torch.cuda.is_available() and gpu_id != -1
    device = f'cuda:{gpu_id}' if use_gpu else 'cpu'

    # Instantiate the architecture first, then fill in the stored weights.
    model = DRNSub(1)

    # The checkpoint is always deserialized onto the CPU; the model is moved
    # to the target device afterwards.
    checkpoint = torch.load(model_path, map_location='cpu', weights_only=True)
    model.load_state_dict(checkpoint['model'])

    model.to(device)
    # Remember the device on the module so callers can place inputs on it.
    model.device = device
    model.eval()
    return model
def predict_fal_detector(model, image, no_crop=True):
    """
    Score an image with the Photoshop FAL detector.

    Args:
        model (torch.nn.Module): Loaded detector. It must expose a ``device``
            attribute (``load_fal_detector`` sets one) naming the device the
            input tensor is moved to.
        image: The input image, handed unchanged to ``generate_local_image``.
            Presumably a PIL image -- confirm against the callers.
        no_crop (bool): Accepted for backward compatibility only. It does not
            influence the result: the face-detection path it used to select
            is disabled, and every input goes through
            ``generate_local_image``.
            NOTE(review): confirm whether a full-image path should be restored.

    Returns:
        float: The sigmoid of the model's output for this image, multiplied
        by 100 (i.e. a percentage in the range 0-100, not a 0-1 probability).
    """
    # The model input is always derived from the project helper
    # generate_local_image, regardless of ``no_crop``.
    face = generate_local_image(image)

    # Resize with a target of 400 for the shorter side; only index 0 of the
    # helper's result is used.
    face = resize_shorter_side(face, 400)[0]
    face_tens = tf(face).to(model.device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        # unsqueeze(0) adds the batch dimension; [0] takes the single
        # result back out of the batch.
        output = model(face_tens.unsqueeze(0))[0]
        fake_prob = output.sigmoid().cpu().item()
    return fake_prob * 100