lily-hust's picture
Update app.py
6d08ce3
raw
history blame
1.75 kB
import streamlit as st
import time
import cv2
import pandas
from PIL import Image
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import preprocess_input
from tensorflow.keras.preprocessing.image import img_to_array
# Page header: app title plus a short description of what the app predicts
# and how to upload images.
st.title('Palm Identification')
st.markdown("This is a Deep Learning application to identify if a satellite image clip contains Palm trees.\n")
st.markdown('The prediction result will be "Palm" or "Others".')
st.markdown('You can click "Browse files" multiple times to add all your images before generating predictions.\n')
# Expected model input size in pixels; presumably the size the classifier
# was trained at (224x224 is the usual ResNet50 input) — TODO confirm.
img_height = 224
img_width = 224
# Label for each output index of the model; order presumably matches the
# label order used in training — TODO confirm.
class_names = ['Palm', 'Others']

# Streamlit re-executes this whole script on every widget interaction, so a
# bare top-level load_model() would re-read the model from disk on every
# click/upload. Keep the loaded model in session state so it is loaded once
# per browser session.
if "model" not in st.session_state:
    st.session_state["model"] = tf.keras.models.load_model('model')
model = st.session_state["model"]
# Seed per-session defaults on the first run of this session:
#   file_uploader_key -- key of the uploader widget; "Clear all" bumps it.
#   uploaded_files    -- the most recent non-empty upload selection.
state = st.session_state
for _name, _initial in (("file_uploader_key", 0), ("uploaded_files", [])):
    if _name not in state:
        state[_name] = _initial
# Accepted extensions must be given as a list: a chain such as
# `"jpg" or "jpeg" or ...` short-circuits to just "jpg", which would reject
# every other format.
uploaded_files = st.file_uploader(
    "Upload images",
    type=["jpg", "jpeg", "bmp", "png", "tif", "tiff"],
    accept_multiple_files=True,
    key=state['file_uploader_key'])

# Remember the latest non-empty selection in session state.
if uploaded_files:
    state["uploaded_files"] = uploaded_files
# "Clear all" resets the uploader by changing its widget key (a new key makes
# Streamlit treat it as a new, empty widget) and then reruns the script.
if st.button("Clear all"):
    state["file_uploader_key"] += 1
    # Drop the remembered selection too, so it does not outlive the cleared widget.
    state["uploaded_files"] = []
    # NOTE(review): purpose of the pause is not evident from the code;
    # presumably it lets the UI settle before the rerun.
    time.sleep(.5)
    # Newer Streamlit releases replace st.experimental_rerun with st.rerun;
    # use whichever this installation provides.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()
def _prepare_image(uploaded_file):
    """Load one uploaded image as a preprocessed single-image batch.

    The image is converted to 3-channel RGB (PNG/TIFF uploads may carry an
    alpha channel or be grayscale) and resized to img_width x img_height so
    every upload matches the model's input size. It is then run through the
    ResNet50 ``preprocess_input``.

    Raises:
        OSError: the file cannot be decoded as an image (Pillow's
            UnidentifiedImageError is a subclass).
        ValueError: Pillow cannot convert the image's mode to RGB.
    """
    img = Image.open(uploaded_file).convert("RGB")
    # PIL sizes are (width, height). NOTE(review): the resampling filter
    # should match the training pipeline — confirm.
    img = img.resize((img_width, img_height))
    img_array = img_to_array(img)
    img_array = tf.expand_dims(img_array, axis=0)  # Create a batch of one
    return preprocess_input(img_array)


if st.button("Generate prediction"):
    if not uploaded_files:
        st.warning("Please upload at least one image first.")
    else:
        for file in uploaded_files:
            try:
                processed_image = _prepare_image(file)
            except (OSError, ValueError) as err:
                # Report the bad file and keep classifying the rest.
                st.error("Could not read image {}: {}".format(file.name, err))
                continue
            predictions = model.predict(processed_image)
            score = predictions[0]
            # file.name is the original filename; formatting `file` itself
            # would print the UploadedFile object's repr.
            st.markdown("Predicted class of the image {} is : {}".format(
                file.name, class_names[np.argmax(score)]))