AbdullahMd12's picture
Update app.py
0e2ca4b
raw
history blame
No virus
3.41 kB
import streamlit as st
import torch
import torchvision.transforms as transforms
from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image
import requests
from PIL import Image
import cv2
import numpy as np
# Streamlit re-executes this script on every interaction. st.cache_resource
# (Streamlit >= 1.18) keeps a single copy of the models alive across reruns;
# on older Streamlit versions this falls back to loading on every run.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _load_models():
    """Load the pretrained ViT feature extractor/classifier and ResNet-50.

    Returns:
        ``(feature_extractor, vit_model, cnn_model)``, with both models
        switched to evaluation mode.
    """
    # Download the pretrained ViT model
    extractor = ViTFeatureExtractor.from_pretrained('google/vit-large-patch16-384')
    vit = ViTForImageClassification.from_pretrained('google/vit-large-patch16-384')
    # Download the pretrained CNN model
    cnn = torch.hub.load('pytorch/vision:v0.9.0', 'resnet50', pretrained=True)
    # Inference only: nn.Module instances default to training mode, which would
    # make BatchNorm use per-batch statistics (and update its running stats).
    vit.eval()
    cnn.eval()
    return extractor, vit, cnn


feature_extractor, vit_model, cnn_model = _load_models()

# Define the image transform: 224x224 RGB tensor normalised to [-1, 1].
# NOTE(review): mean/std 0.5 is not the ImageNet normalisation the torchvision
# ResNet-50 weights were trained with, and the ViT checkpoint expects 384x384
# input — confirm before using this transform to feed either model.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
# Preprocessing for the torchvision ResNet-50: 224x224 input normalised with the
# ImageNet channel statistics its pretrained weights expect.
_CNN_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])


# Define the function to predict whether an image is genuine or morphed
def predict(image):
    """Classify an image with the ViT and ResNet-50 models and combine them.

    Args:
        image: Image as produced by ``cv2.imdecode(..., cv2.IMREAD_COLOR)``,
            i.e. a uint8 array in BGR channel order.

    Returns:
        ``(combined_pred, combined_score)`` where ``combined_pred`` is the
        argmax class index of whichever model has the higher top probability,
        and ``combined_score`` is ``0.7 * vit_top_prob + 0.3 * cnn_top_prob``.

    NOTE(review): both models are generic ImageNet classifiers, so the class
    index is an ImageNet label, not a genuine/morphed flag, and the two top
    probabilities may refer to different classes. Confirm the intended models
    and combining rule.
    """
    # OpenCV delivers BGR; PIL and both models expect RGB.
    pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

    # The ViT checkpoint ships its own preprocessing (resize to 384x384 for
    # vit-large-patch16-384 plus its normalisation); a 224x224 tensor is
    # rejected by the model's patch embedding.
    vit_inputs = feature_extractor(images=pil_image, return_tensors="pt")
    cnn_input = _CNN_TRANSFORM(pil_image).unsqueeze(0)  # add batch dimension

    # Idempotent; guarantees BatchNorm uses running statistics at inference.
    cnn_model.eval()

    with torch.no_grad():
        # Hugging Face models return an output object; class scores are `.logits`.
        vit_logits = vit_model(**vit_inputs).logits
        cnn_logits = cnn_model(cnn_input)

    vit_score, vit_pred = torch.max(torch.nn.functional.softmax(vit_logits, dim=1), 1)
    cnn_score, cnn_pred = torch.max(torch.nn.functional.softmax(cnn_logits, dim=1), 1)

    # Combine the predictions using a weighted average of the top probabilities;
    # the label comes from whichever model is more confident.
    vit_top, cnn_top = vit_score.item(), cnn_score.item()
    combined_score = 0.7 * vit_top + 0.3 * cnn_top
    combined_pred = vit_pred.item() if vit_top > cnn_top else cnn_pred.item()
    return combined_pred, combined_score
# Define the function to restore an image
def restore(image, ksize=5):
    """Denoise an image with a median blur.

    Args:
        image: Image array accepted by ``cv2.medianBlur`` (the app passes the
            decoded uint8 BGR upload).
        ksize: Median aperture size; must be an odd integer greater than 1.
            Defaults to 5, the previously hard-coded value.

    Returns:
        The median-filtered image, same size and type as ``image``.

    Raises:
        ValueError: If ``ksize`` is not an odd integer greater than 1.
    """
    if ksize < 3 or ksize % 2 == 0:
        raise ValueError(f"ksize must be an odd integer greater than 1, got {ksize!r}")
    return cv2.medianBlur(image, ksize)
# Define the function to enhance an image
def enhance(image, clip_limit=3.0, tile_grid_size=(8, 8)):
    """Increase local contrast by applying CLAHE to the L channel in LAB space.

    Only lightness is equalised; the a/b colour channels are passed through
    unchanged, so hues are preserved.

    Args:
        image: 3-channel BGR image (the app passes the decoded uint8 upload).
        clip_limit: CLAHE contrast-limiting threshold (default 3.0, the
            previously hard-coded value).
        tile_grid_size: ``(cols, rows)`` grid of tiles for CLAHE (default
            ``(8, 8)``, the previously hard-coded value).

    Returns:
        The contrast-enhanced image, converted back to BGR.
    """
    lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
    lightness, a_channel, b_channel = cv2.split(lab)
    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    lab = cv2.merge((clahe.apply(lightness), a_channel, b_channel))
    return cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)
# Define the Streamlit app
def app():
    """Render the page: upload an image, show the prediction, restored and enhanced views."""
    st.title("Advanced Face Morphing Detection and Restoration")
    # Upload an image
    uploaded_file = st.file_uploader("Choose an image", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    # np.fromstring is deprecated for binary data; frombuffer views the bytes.
    file_bytes = np.frombuffer(uploaded_file.read(), np.uint8)
    image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
    if image is None:
        # imdecode signals corrupt/unsupported data by returning None.
        st.error("Could not decode the uploaded file as an image.")
        return

    # All arrays below are BGR (OpenCV order); st.image assumes RGB by default.
    st.image(image, caption="Uploaded Image", channels="BGR", use_column_width=True)

    # Predict whether the image is genuine or morphed and show the prediction score
    prediction, score = predict(image)
    if prediction == 0:
        st.write("The image is genuine with a score of {:.2f}".format(score))
    else:
        st.write("The image is morphed with a score of {:.2f}".format(score))

    # Restore the image
    restored_image = restore(image)
    st.image(restored_image, caption="Restored Image", channels="BGR", use_column_width=True)

    # Enhance the image
    enhanced_image = enhance(image)
    st.image(enhanced_image, caption="Enhanced Image", channels="BGR", use_column_width=True)
# Entry point: build the page only when this file runs as the main script
# (e.g. via `streamlit run`), not when it is imported.
if "__main__" == __name__:
    app()