File size: 3,410 Bytes
0e2ca4b
 
 
 
 
 
 
 
 
bf23a2d
0e2ca4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import streamlit as st
import torch
import torchvision.transforms as transforms
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image
import requests
from PIL import Image
import cv2
import numpy as np

# Load the pretrained ViT checkpoint together with its matching preprocessor.
# NOTE(review): Streamlit re-executes the whole script on every user interaction,
# so these loads run on each rerun; consider caching them (e.g. with
# st.cache_resource) if the installed Streamlit version provides it.
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-large-patch16-384')
vit_model = ViTForImageClassification.from_pretrained('google/vit-large-patch16-384')

# Load the pretrained ResNet-50 from the torchvision hub repo (v0.9.0 API,
# where `pretrained=True` is the supported way to request weights).
cnn_model = torch.hub.load('pytorch/vision:v0.9.0', 'resnet50', pretrained=True)
# torchvision model constructors return modules in training mode. Switch to
# inference mode so BatchNorm uses its stored running statistics (and stops
# updating them) instead of the statistics of each single-image batch.
cnn_model.eval()

# Generic image transform: resize to 224x224, scale pixels to [0, 1] with
# ToTensor, then map each channel to [-1, 1] via (x - 0.5) / 0.5.
# NOTE(review): the 'google/vit-large-patch16-384' checkpoint name indicates
# 384x384 inputs, and torchvision's ResNet-50 weights expect ImageNet mean/std
# normalization, so this transform matches neither model exactly — confirm
# before using it for inference.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# ImageNet channel statistics used to train torchvision's pretrained ResNet-50
# weights; the CNN input must be normalized with these rather than 0.5/0.5.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)


def _vit_input(pil_image):
    """Preprocess *pil_image* with the ViT checkpoint's own feature extractor.

    Returns the ``pixel_values`` tensor (batch of one), resized and normalized
    according to the preprocessor config loaded alongside ``vit_model``.
    """
    return feature_extractor(pil_image, return_tensors="pt")["pixel_values"]


def _cnn_input(pil_image):
    """Preprocess *pil_image* for ResNet-50: 224x224, [0, 1] scaling, ImageNet normalization.

    Returns a float tensor with a leading batch dimension of one.
    """
    cnn_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ])
    return cnn_transform(pil_image).unsqueeze(0)


def _top1(logits):
    """Return ``(probability, class_index)`` of the most likely class in a (1, C) logits tensor."""
    probs = torch.nn.functional.softmax(logits, dim=1)
    score, pred = torch.max(probs, dim=1)
    return score.item(), pred.item()


def predict(image):
    """Classify a BGR image with a ViT + ResNet-50 ensemble.

    Args:
        image: OpenCV-style BGR ``uint8`` array (as returned by ``cv2.imdecode``).

    Returns:
        ``(combined_pred, combined_score)`` where ``combined_pred`` is the class
        index from whichever model is more confident and ``combined_score`` is
        ``0.7 * vit_probability + 0.3 * cnn_probability`` of the two top-1
        probabilities.

    NOTE(review): both backbones appear to be generic pretrained classifiers
    (presumably ImageNet-1k heads — confirm), so ``combined_pred`` is a class
    index from that label space, not a genuine/morphed decision.
    """
    # OpenCV decodes to BGR; PIL and both models expect RGB.
    pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

    with torch.no_grad():
        # HF models return an ImageClassifierOutput; the raw scores live in .logits.
        vit_score, vit_pred = _top1(vit_model(pixel_values=_vit_input(pil_image)).logits)
        # The torchvision ResNet returns the logits tensor directly.
        cnn_score, cnn_pred = _top1(cnn_model(_cnn_input(pil_image)))

    # Weighted confidence, but the label comes from the more confident model.
    combined_score = 0.7 * vit_score + 0.3 * cnn_score
    combined_pred = vit_pred if vit_score > cnn_score else cnn_pred
    return combined_pred, combined_score

def restore(image):
    """Denoise *image* with a 5x5 median filter and return the filtered result."""
    aperture = 5  # median kernel side length; OpenCV requires an odd value > 1
    return cv2.medianBlur(image, aperture)

def enhance(image):
    """Boost local contrast of a BGR image with CLAHE on its lightness channel.

    The image is converted to LAB, CLAHE (clip limit 3.0, 8x8 tiles) is applied
    to the L channel only, the A/B colour channels are left unchanged, and the
    result is converted back to BGR.
    """
    equalizer = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
    lightness, chroma_a, chroma_b = cv2.split(cv2.cvtColor(image, cv2.COLOR_BGR2LAB))
    equalized_lab = cv2.merge((equalizer.apply(lightness), chroma_a, chroma_b))
    return cv2.cvtColor(equalized_lab, cv2.COLOR_LAB2BGR)

def app():
    """Render the Streamlit UI.

    Lets the user upload an image, shows it, reports the ensemble prediction
    from ``predict``, and displays median-filtered (``restore``) and
    CLAHE-enhanced (``enhance``) versions of the upload.
    """
    st.title("Advanced Face Morphing Detection and Restoration")

    # Upload an image
    uploaded_file = st.file_uploader("Choose an image", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    # getvalue() returns the whole upload regardless of stream position, and
    # np.frombuffer replaces the deprecated binary mode of np.fromstring.
    raw_bytes = uploaded_file.getvalue()
    # imdecode returns None for undecodable data and asserts on an empty buffer.
    image = cv2.imdecode(np.frombuffer(raw_bytes, dtype=np.uint8), cv2.IMREAD_COLOR) if raw_bytes else None
    if image is None:
        st.error("Could not decode the uploaded file as an image.")
        return

    # OpenCV arrays are BGR; tell Streamlit so red and blue are not swapped.
    st.image(image, caption="Uploaded Image", channels="BGR", use_column_width=True)

    # Predict and show the score.
    # NOTE(review): `prediction` is a classifier class index; treating index 0 as
    # "genuine" and every other index as "morphed" is an assumption to confirm.
    prediction, score = predict(image)
    if prediction == 0:
        st.write("The image is genuine with a score of {:.2f}".format(score))
    else:
        st.write("The image is morphed with a score of {:.2f}".format(score))

    # Restore the image
    restored_image = restore(image)
    st.image(restored_image, caption="Restored Image", channels="BGR", use_column_width=True)

    # Enhance the image (applied to the original upload, not the restored one)
    enhanced_image = enhance(image)
    st.image(enhanced_image, caption="Enhanced Image", channels="BGR", use_column_width=True)

# Render the app only when this module is executed as the main script,
# not when it is imported.
if __name__ == "__main__":
    app()