File size: 2,564 Bytes
d7f12b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import torch
from torchvision import transforms
from PIL import Image
import torch.nn as nn
import torch
import torch.nn as nn
from torchvision.models import resnet50
from utils.preprocessing import *
from models.networks.drn_seg import DRNSub
from utils.preprocessing import generate_local_image


# Preprocessing applied to every image before it is fed to the detector:
# convert to a float tensor (ToTensor), then normalize each channel with the
# mean/std below. These match the commonly used ImageNet statistics;
# presumably the detector weights were trained with the same values — confirm.
tf = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def load_fal_detector(model_path='models/weights/global.pth', gpu_id=-1):
    """
    Load the Photoshop FAL detector model with the specified weights.

    Args:
        model_path (str): Path to a checkpoint file whose ``'model'`` entry
            holds the ``DRNSub`` state dict.
        gpu_id (int): CUDA device index to use. Any negative value (the
            default is -1) selects the CPU. If CUDA is not available, the
            CPU is used regardless of this value.

    Returns:
        torch.nn.Module: The loaded model in evaluation mode. The selected
        device string (e.g. ``'cpu'`` or ``'cuda:0'``) is stored on the model
        as ``model.device`` so callers can move inputs to the same device.
    """
    # Treat every negative id as "CPU". Previously only exactly -1 did, so
    # e.g. gpu_id=-2 produced the invalid device string 'cuda:-2'.
    if gpu_id >= 0 and torch.cuda.is_available():
        device = f'cuda:{gpu_id}'
    else:
        device = 'cpu'

    # Build the network architecture (weights are loaded below).
    model = DRNSub(1)

    # Load onto the CPU first so a checkpoint saved on a GPU can be opened on
    # any machine; weights_only=True restricts unpickling to tensors and
    # primitive containers rather than arbitrary Python objects.
    state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
    model.load_state_dict(state_dict['model'])

    # Move to the target device, remember it for inference-time input
    # placement, and switch to evaluation mode.
    model.to(device)
    model.device = device
    model.eval()

    return model


def predict_fal_detector(model, image, no_crop=True):
    """
    Score how likely an image is to have been modified using Photoshop FAL.

    Args:
        model (torch.nn.Module): Detector, e.g. as returned by
            ``load_fal_detector``. If the model has a ``device`` attribute,
            the input tensor is sent there; otherwise the device of the
            model's first parameter is used.
        image: Input image, passed as-is to ``generate_local_image``.
            Presumably a PIL image — confirm against callers.
        no_crop (bool): Accepted for backward compatibility only. Face
            cropping is not implemented, so the full image is always
            processed and this flag currently has no effect.

    Returns:
        float: The model's sigmoid output scaled to [0, 100], i.e. the
        modification probability expressed as a percentage.
    """
    # Face detection/cropping (the no_crop=False path) is not implemented;
    # both settings of no_crop run the same pipeline on the whole image.
    local_image = generate_local_image(image)

    # resize_shorter_side presumably scales the shorter side to 400 px and
    # returns a tuple whose first element is the resized image — verify in
    # utils.preprocessing.
    face = resize_shorter_side(local_image, 400)[0]

    # load_fal_detector attaches `.device`; fall back to the parameters'
    # device so models constructed elsewhere also work.
    device = getattr(model, 'device', None)
    if device is None:
        device = next(model.parameters()).device
    face_tens = tf(face).to(device)

    # Single-image batch; the first output element is taken as the logit.
    with torch.no_grad():
        fake_prob = model(face_tens.unsqueeze(0))[0].sigmoid().cpu().item()

    return fake_prob * 100