# Author: David Fischinger
# Last commit: 6c58fe6 ("fix versions in requirements.txt")
# File size: 4.14 kB
from PIL import Image
import streamlit as st
import cv2
import numpy as np
import os
import tensorflow as tf
from IMVIP_Supplementary_Material.scripts import dfutils #methods used for DF-Net
# Markdown rendered at the top of the page (see st.markdown below).
DESCRIPTION = """# DF-Net
The Digital Forensics Network is designed and trained to detect and locate image manipulations.
More information can be found in this [publication](https://zenodo.org/record/8214996)
#### Select example image or upload your own image:
"""

# Side length in pixels of the square model input: every image is resized
# to IMG_SIZE x IMG_SIZE before being passed to the DF-Net models.
IMG_SIZE = 256

# Give TensorFlow tensors NumPy-style methods/properties and type promotion.
tf.experimental.numpy.experimental_enable_numpy_behavior()
# function to load models
#@st.session_state better for hugging face?
@st.cache_resource
def load_models():
#load models
model_path1 = "IMVIP_Supplementary_Material/models/model1/"
model_path2 = "IMVIP_Supplementary_Material/models/model2/"
model_M1 = tf.keras.models.load_model("IMVIP_Supplementary_Material/models/model1/model1.h5")
model_M2 = tf.keras.models.load_model("IMVIP_Supplementary_Material/models/model2/model2.h5")
return model_M1, model_M2
model_M1, model_M2 = load_models()
def check_forgery_df(img):
    """Predict a manipulation mask for ``img`` with both DF-Net models.

    Steps:
    1. Resize the image to ``IMG_SIZE`` x ``IMG_SIZE``, cast it to float32,
       divide by 255 and add a leading batch axis.
    2. Run ``model_M1`` and ``model_M2`` on the batch.
    3. Fuse the two predictions with an element-wise maximum.
    4. Turn the result into a mask with ``dfutils.create_mask``.
    5. Resize the mask back to the original width and height.

    Args:
        img: Image array; ``img.shape[0]`` and ``img.shape[1]`` are used as
            the original height and width. Callers in this file pass an RGB
            uint8 array (assumed; not checked here).

    Returns:
        The mask, bilinearly resized to ``img``'s height and width.
    """
    original_shape = img.shape
    scaled = cv2.resize(img, (IMG_SIZE, IMG_SIZE)).astype("float32") / 255.0
    batch = scaled[np.newaxis, ...]

    # Keep, per element, the larger of the two model predictions.
    fused = np.max(
        (model_M1.predict(batch, verbose=0), model_M2.predict(batch, verbose=0)),
        axis=0,
    )

    mask = dfutils.create_mask(fused)
    # Keep only the two axes before the last one. This reshape succeeds only
    # when every other axis has size 1. The assumed layout (1, H, W, 1) from
    # create_mask is not visible here; confirm it in dfutils.
    mask = mask.reshape(mask.shape[-3:-1])

    # cv2.resize takes dsize as (width, height).
    target_size = (original_shape[1], original_shape[0])
    return cv2.resize(mask, target_size, interpolation=cv2.INTER_LINEAR)
def evaluate(img):
    """Run forgery detection on ``img`` and display the resulting mask."""
    st.image(
        check_forgery_df(img),
        caption="White area indicates potential image manipulations.",
    )
def start_evaluation(uploaded_file):
    """Decode an uploaded image file, display it and run forgery detection.

    Args:
        uploaded_file: File-like object (as returned by ``st.file_uploader``)
            whose ``read()`` yields the encoded image bytes.

    The image is decoded by OpenCV in BGR channel order and converted to RGB
    before it is displayed and evaluated. If the bytes cannot be decoded as
    an image, an error message is shown instead of raising.
    """
    file_bytes = np.frombuffer(uploaded_file.read(), dtype=np.uint8)
    # IMREAD_COLOR always yields a 3-channel BGR image.
    opencv_image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
    if opencv_image is None:
        # imdecode reports undecodable data (corrupt or mislabelled files)
        # by returning None rather than raising.
        st.error("The uploaded file could not be decoded as an image.")
        return
    rgb_image = cv2.cvtColor(opencv_image, cv2.COLOR_BGR2RGB)
    st.image(rgb_image, caption="Input Image")
    evaluate(rgb_image)
def start_evaluation_pil_img(pil_image):
    """Display a PIL image and run forgery detection on it.

    Args:
        pil_image: A ``PIL.Image.Image`` in any mode.

    The image is converted to RGB first, so the detector always receives a
    3-channel (H, W, 3) uint8 array. For RGB input the pixels are unchanged.
    Single-channel or palette modes (e.g. "L", "P") are expanded to RGB, and
    an alpha channel is dropped.
    """
    rgb_image = np.array(pil_image.convert("RGB"))
    st.image(rgb_image, caption="Input Image")
    evaluate(rgb_image)
st.markdown(DESCRIPTION)

# Example images; each one's ground-truth mask shares the base path with a
# "_gt.png" suffix.
img_path1 = "example_images/Sp_D_NRD_A_nat0095_art0058_0582"
img_path2 = "example_images/Sp_D_NRN_A_nat0083_arc0080_0445"
example_stems = (img_path1, img_path2)
image_paths = [stem + ".jpg" for stem in example_stems]
gt_paths = [stem + "_gt.png" for stem in example_stems]

# One row per example: select button | example image | ground truth. The
# fourth column is left empty. img stays None unless a "Select Image" button
# was clicked on this run.
img = None
for idx, (image_path, gt_path) in enumerate(zip(image_paths, gt_paths)):
    cols = st.columns([2, 2, 2, 2])
    if cols[0].button(f"Select Image {idx + 1}", key=idx):
        img = Image.open(image_path)
    cols[1].image(image_path, use_column_width=True, caption=f"Example Image {idx + 1}")
    cols[2].image(gt_path, use_column_width=True, caption="Ground Truth")

if img is not None:
    start_evaluation_pil_img(img)
def reset_image_select():
    """``on_change`` callback for the image file uploader.

    This is intentionally a no-op. Streamlit reruns the whole script after
    the callback, and the script sets ``img = None`` before drawing the
    example buttons. So any previous example selection is already cleared
    on that rerun. Assigning ``img``/``uploaded_file`` inside this function
    would only create function locals, which have no effect on the module
    state.
    """
    return None
# Evaluate an uploaded image, unless an example image was selected on this run.
uploaded_file = st.file_uploader(
    "Please upload an image",
    type=["jpeg", "jpg", "png"],
    on_change=reset_image_select,
)
if img is None and uploaded_file is not None:
    start_evaluation(uploaded_file)