# Page header captured together with this file (not code):
# author rushidarge · commit 0f31dfe "Update app.py" · size 1.53 kB
import cv2
import time
from PIL import Image, ImageOps
import numpy as np
import streamlit as st
import tensorflow as tf
# Choose the model-caching decorator offered by the installed Streamlit.
# st.cache_resource replaces the deprecated st.cache(allow_output_mutation=True)
# for objects such as models; fall back to the old API on older releases.
if hasattr(st, "cache_resource"):
    _cache_model = st.cache_resource
else:
    _cache_model = st.cache(allow_output_mutation=True)


@_cache_model
def load_model(path='MN21.h5'):
    """Load the Keras classifier from *path*, cached across Streamlit reruns.

    Args:
        path: location of the saved Keras model file. Defaults to
            ``'MN21.h5'``, resolved relative to the current working directory.

    Returns:
        The model object returned by ``tf.keras.models.load_model``.
    """
    model = tf.keras.models.load_model(path)
    return model
# Page setup: fetch the (cached) classifier behind a spinner, then render the
# title and the upload widget. `model` and `file` are used further down.
with st.spinner('Model is being loaded..'):
    model = load_model()

st.write("""# SaferNet with AI""")

file = st.file_uploader(
    "Please upload an image to classify",
    type=["jpg", "png", "jpeg"],
)
def import_and_predict(image_data, model):
    """Preprocess a PIL image and run it through the classifier.

    The image is converted to 3-channel RGB and centre-cropped/resized to
    224x224. Its channels are then swapped with OpenCV, and a leading batch
    axis is added before the image is passed to ``model.predict``. Pixel
    values are left unscaled (0-255).

    Args:
        image_data: a ``PIL.Image.Image`` in any mode PIL can convert to RGB.
        model: an object with a Keras-style ``predict`` method.

    Returns:
        ``(prediction, time_take)``: the raw output of ``model.predict`` and
        the elapsed seconds spent inside that call.
    """
    size = (224, 224)
    # Grayscale, palette and alpha images would otherwise yield arrays with
    # the wrong channel count and break cvtColor / the 3-channel batch below.
    rgb_image = image_data.convert("RGB")
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
    fitted = ImageOps.fit(rgb_image, size, Image.LANCZOS)
    pixels = np.asarray(fitted)
    # PIL delivers RGB, so this swap actually yields BGR channel order.
    # NOTE(review): presumably matches how the training images were loaded
    # (e.g. cv2.imread) — confirm before changing.
    img = cv2.cvtColor(pixels, cv2.COLOR_BGR2RGB)
    img_batch = np.expand_dims(img, axis=0)  # (1, 224, 224, 3)
    # Debug output: show the batch shape in the app.
    st.write(img_batch.shape)
    # perf_counter is the monotonic, high-resolution clock meant for intervals.
    start = time.perf_counter()
    prediction = model.predict(img_batch)
    time_take = time.perf_counter() - start
    return prediction, time_take
# Main flow: prompt for an upload, or classify the uploaded image and show
# the timing, raw scores and the SFW/NSFW verdict.
if file is None:
    st.text("Please upload an image file")
else:
    # Decode eagerly so a corrupt or non-image upload is reported in the UI
    # instead of crashing the script later in st.image / preprocessing.
    try:
        image = Image.open(file)
        image.load()
    except OSError:  # PIL.UnidentifiedImageError is a subclass of OSError
        image = None
        st.error("Could not read the uploaded file as an image.")

    if image is not None:
        st.image(image, use_column_width=True)
        predictions, time_take = import_and_predict(image, model)
        st.write("Time taken to predict is ", time_take, "second")
        # Debug output: raw model scores and their shape.
        st.write(predictions)
        st.write(predictions.shape)
        # NOTE(review): assumes the model emits at least two scores per image;
        # predictions[0][1] raises IndexError otherwise — confirm.
        st.write('0,0', predictions[0][0])
        st.write('0,1', predictions[0][1])
        # A first score below 0.5 is reported as SFW, anything else as NSFW.
        if predictions[0][0] < 0.5:
            st.write("This is SFW image :sunglasses:")
        else:
            st.write("This is NSFW image")