"""Streamlit app that scores an uploaded image and shows cleaned-up versions.

An uploaded image is classified by a ViT and a ResNet-50, the two confidences
are blended into one score, and a median-blurred ("restored") and a
CLAHE contrast-enhanced version of the image are displayed.
"""
from typing import Tuple

import cv2
import numpy as np
import requests
import streamlit as st
import torch
import torchvision.transforms as transforms
from PIL import Image
from transformers import ViTFeatureExtractor, ViTForImageClassification

# Weights used when blending the two models' confidence scores.
VIT_WEIGHT = 0.7
CNN_WEIGHT = 0.3

# Download the pretrained ViT model together with its matching preprocessor.
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-large-patch16-384')
vit_model = ViTForImageClassification.from_pretrained('google/vit-large-patch16-384')

# Download the pretrained CNN model.
cnn_model = torch.hub.load('pytorch/vision:v0.9.0', 'resnet50', pretrained=True)

# Inference only: torch.hub returns the ResNet in training mode, where
# BatchNorm would normalise with the statistics of our single-image batch.
vit_model.eval()
cnn_model.eval()

# Preprocessing for the ResNet branch only. The ViT branch goes through
# `feature_extractor`, which resizes to the resolution that checkpoint expects.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # ImageNet statistics, as expected by torchvision's pretrained weights.
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])


def predict(image: np.ndarray) -> Tuple[int, float]:
    """Classify an image with both models and blend their confidences.

    Args:
        image: Colour image as a BGR ``uint8`` array (OpenCV layout).

    Returns:
        ``(combined_pred, combined_score)`` where ``combined_pred`` is the
        class index chosen by whichever model was more confident and
        ``combined_score`` is the 0.7/0.3 weighted mix of the two models'
        top-class probabilities.
    """
    # NOTE(review): both checkpoints are stock ImageNet classifiers, so the
    # returned index is an ImageNet class id. Treating index 0 as "genuine"
    # (as `app` does) only becomes meaningful once these are swapped for
    # weights fine-tuned on a genuine/morphed dataset.

    # Convert the BGR numpy array to an RGB PIL image.
    rgb_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

    # Each model gets its own preprocessing; they expect different input
    # resolutions and normalisation.
    vit_inputs = feature_extractor(images=rgb_image, return_tensors="pt")
    cnn_input = transform(rgb_image).unsqueeze(0)

    with torch.no_grad():
        # The transformers model returns an output object; the class scores
        # live in its `.logits` attribute.
        vit_logits = vit_model(**vit_inputs).logits
        cnn_logits = cnn_model(cnn_input)

    vit_probs = torch.nn.functional.softmax(vit_logits, dim=1)
    cnn_probs = torch.nn.functional.softmax(cnn_logits, dim=1)
    vit_score, vit_pred = torch.max(vit_probs, 1)
    cnn_score, cnn_pred = torch.max(cnn_probs, 1)

    vit_conf = vit_score.item()
    cnn_conf = cnn_score.item()

    # Weighted average for the score; the label comes from the more
    # confident of the two models.
    combined_score = VIT_WEIGHT * vit_conf + CNN_WEIGHT * cnn_conf
    combined_pred = vit_pred.item() if vit_conf > cnn_conf else cnn_pred.item()
    return combined_pred, combined_score


def restore(image: np.ndarray) -> np.ndarray:
    """Return a copy of ``image`` smoothed with a 5x5 median blur."""
    return cv2.medianBlur(image, 5)


def enhance(image: np.ndarray) -> np.ndarray:
    """Return ``image`` (BGR) with its contrast increased.

    CLAHE is applied to the lightness channel in LAB space so that only
    contrast changes and the colour channels are left untouched.
    """
    lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
    lightness, a_channel, b_channel = cv2.split(lab)
    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
    lightness = clahe.apply(lightness)
    lab = cv2.merge((lightness, a_channel, b_channel))
    return cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)


def app() -> None:
    """Render the Streamlit page: upload, prediction, restoration, enhancement."""
    st.title("Advanced Face Morphing Detection and Restoration")

    # Upload an image
    uploaded_file = st.file_uploader("Choose an image", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    # np.frombuffer replaces the deprecated binary mode of np.fromstring.
    raw_bytes = np.frombuffer(uploaded_file.read(), np.uint8)
    # imdecode rejects an empty buffer and returns None for undecodable data.
    image = cv2.imdecode(raw_bytes, cv2.IMREAD_COLOR) if raw_bytes.size else None
    if image is None:
        st.error("The uploaded file could not be decoded as an image.")
        return

    # Predict whether the image is genuine or morphed and show the score
    prediction, score = predict(image)
    if prediction == 0:
        st.write(f"The image is genuine with a score of {score:.2f}")
    else:
        st.write(f"The image is morphed with a score of {score:.2f}")

    # OpenCV arrays are BGR; tell Streamlit so colours are not swapped.
    restored_image = restore(image)
    st.image(restored_image, caption="Restored Image", use_column_width=True,
             channels="BGR")

    enhanced_image = enhance(image)
    st.image(enhanced_image, caption="Enhanced Image", use_column_width=True,
             channels="BGR")


# Run the Streamlit app
if __name__ == '__main__':
    app()