Spaces:
Build error
Build error
import streamlit as st | |
import torch | |
import torchvision.transforms as transforms | |
from transformers import ViTFeatureExtractor, ViTForImageClassification | |
from PIL import Image | |
import requests | |
from PIL import Image | |
import cv2 | |
import numpy as np | |
# Download the pretrained ViT model together with the preprocessor that
# matches its checkpoint (input resolution and normalisation).
feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-large-patch16-384')
vit_model = ViTForImageClassification.from_pretrained('google/vit-large-patch16-384')
# Hugging Face's from_pretrained already puts the model in eval mode; this
# call is explicit so both models are visibly configured the same way.
vit_model.eval()
# Download the pretrained CNN model.
cnn_model = torch.hub.load('pytorch/vision:v0.9.0', 'resnet50', pretrained=True)
# nn.Module instances start in training mode, and nothing here changes that
# for the hub model. Switch to inference mode so BatchNorm uses its learned
# running statistics instead of statistics computed from the single image
# being classified (and Dropout, if any, is disabled).
cnn_model.eval()
# Define the image transform: resize to 224x224 and map pixels to [-1, 1]
# with mean/std 0.5 per channel.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
# Preprocessing for torchvision's ImageNet-pretrained ResNet-50: 224x224 input
# normalised with the ImageNet channel mean/std. The ViT-style (0.5, 0.5, 0.5)
# normalisation is not what this network was trained on.
_CNN_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])


# Define the function to predict whether an image is genuine or morphed
def predict(image):
    """Score a BGR image with the ViT classifier and the ResNet-50 classifier.

    Args:
        image: Image array in OpenCV's BGR channel order, as produced by
            ``cv2.imdecode(..., cv2.IMREAD_COLOR)``.

    Returns:
        A ``(combined_pred, combined_score)`` tuple:

        * ``combined_pred`` is the top-1 class index of whichever model
          reports the higher top-1 probability.
        * ``combined_score`` is ``0.7 * vit_top1_prob + 0.3 * cnn_top1_prob``.

    Raises:
        ValueError: If ``image`` is ``None`` (for example, the upload could
            not be decoded).

    NOTE(review): both models are generic ImageNet classifiers loaded from
    public checkpoints. Nothing here is trained to separate genuine from
    morphed faces, so reading class index 0 as "genuine" is not supported by
    these models. Also, the two top-1 probabilities may refer to different
    classes, so their weighted average has no single-class meaning.
    """
    if image is None:
        raise ValueError("predict() received no image; decoding likely failed")

    # Convert the OpenCV BGR array to an RGB PIL image.
    pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

    # The ViT checkpoint expects its own input size (384x384 for
    # vit-large-patch16-384). Feeding it 224x224 is rejected unless position
    # embeddings are interpolated, so use the checkpoint's preprocessor.
    vit_pixels = feature_extractor(images=pil_image, return_tensors="pt")["pixel_values"]
    cnn_pixels = _CNN_TRANSFORM(pil_image).unsqueeze(0)

    with torch.no_grad():
        # Hugging Face models return a ModelOutput object; the class scores
        # are in `.logits`. The torchvision model returns the logits tensor.
        vit_logits = vit_model(pixel_values=vit_pixels).logits
        cnn_logits = cnn_model(cnn_pixels)

    vit_probs = torch.nn.functional.softmax(vit_logits, dim=1)
    vit_score, vit_pred = torch.max(vit_probs, 1)
    cnn_probs = torch.nn.functional.softmax(cnn_logits, dim=1)
    cnn_score, cnn_pred = torch.max(cnn_probs, 1)

    # Combine the predictions: weighted average of the confidences, and the
    # class label taken from the more confident model.
    combined_score = 0.7 * vit_score.item() + 0.3 * cnn_score.item()
    combined_pred = vit_pred.item() if vit_score.item() > cnn_score.item() else cnn_pred.item()
    return combined_pred, combined_score
# Define the function to restore an image
def restore(image):
    """Denoise ``image`` with a 5x5 median filter and return the result.

    The input array is not modified; OpenCV returns a new filtered array.
    """
    kernel_size = 5  # OpenCV requires an odd aperture greater than 1
    return cv2.medianBlur(image, kernel_size)
# Define the function to enhance an image
def enhance(image):
    """Boost local contrast of a BGR image using CLAHE on its lightness channel.

    The image is converted to LAB. Contrast-limited adaptive histogram
    equalisation (clip limit 3.0, 8x8 tiles) is applied to the L channel
    only, so the a/b colour channels are untouched. The result is converted
    back to BGR.
    """
    lightness, chan_a, chan_b = cv2.split(cv2.cvtColor(image, cv2.COLOR_BGR2LAB))
    equalizer = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
    boosted_lab = cv2.merge((equalizer.apply(lightness), chan_a, chan_b))
    return cv2.cvtColor(boosted_lab, cv2.COLOR_LAB2BGR)
# Define the Streamlit app
def app():
    """Render the page: upload an image, classify it, then restore and enhance it.

    Undecodable or empty uploads produce an on-page error instead of a crash.
    """
    st.title("Advanced Face Morphing Detection and Restoration")
    # Upload an image
    uploaded_file = st.file_uploader("Choose an image", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    # getvalue() returns the whole payload regardless of the stream position,
    # unlike read(). np.frombuffer replaces the deprecated binary mode of
    # np.fromstring.
    raw_bytes = np.frombuffer(uploaded_file.getvalue(), np.uint8)
    # cv2.imdecode asserts on an empty buffer and returns None on data it
    # cannot decode; both cases are reported to the user.
    image = cv2.imdecode(raw_bytes, cv2.IMREAD_COLOR) if raw_bytes.size else None
    if image is None:
        st.error("Could not decode the uploaded file as an image.")
        return

    # OpenCV arrays are BGR, while st.image assumes RGB by default. Pass
    # channels="BGR" so the colours are not swapped.
    st.image(image, caption="Uploaded Image", channels="BGR", use_column_width=True)

    # Predict whether the image is genuine or morphed and show the score.
    # NOTE(review): the underlying models are generic ImageNet classifiers,
    # so this genuine/morphed label is not backed by a trained detector.
    prediction, score = predict(image)
    if prediction == 0:
        st.write("The image is genuine with a score of {:.2f}".format(score))
    else:
        st.write("The image is morphed with a score of {:.2f}".format(score))

    # Restore the image
    restored_image = restore(image)
    st.image(restored_image, caption="Restored Image", channels="BGR", use_column_width=True)
    # Enhance the image
    enhanced_image = enhance(image)
    st.image(enhanced_image, caption="Enhanced Image", channels="BGR", use_column_width=True)
# Build the page when this file is executed as a script (``streamlit run``
# executes the target file with ``__name__ == "__main__"``).
if __name__ == "__main__":
    app()