# jacaranda-app / app.py
# Last updated by lily-hust: "Update app.py" (commit 233cf7e)
import streamlit as st
import pandas as pd
from PIL import Image
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import preprocess_input
from tensorflow.keras.preprocessing.image import img_to_array
from datasets import load_dataset
def main():
    """Render the page title and usage instructions, then run the app body."""
    st.title('Jacaranda Identification')
    st.markdown("This is a Deep Learning application to identify if an aerial image clip contains Jacaranda trees.\n")
    st.markdown('The prediction result will be "Jacaranda" or "Others".')
    st.markdown('You can click "Browse files" multiple times to add all images before generating predictions.\n')
    st.markdown('The image clips can look like these examples. The image size can be arbitrary, but clips should preferably not include too much content.')
    # run_the_app() displays the example images referred to just above.
    run_the_app()
@st.cache_resource()
def load_model():
    """Load the Keras classifier stored at the local path 'model'.

    Decorated with ``st.cache_resource`` so the network is loaded once and
    reused on subsequent calls instead of being read again on every rerun.
    """
    network = tf.keras.models.load_model('model')
    return network
@st.cache_data()
def generate_df():
    """Return an empty results table with the two CSV output columns.

    Columns are 'Image file name' and 'Class name'; rows are added later by
    ``write_df``.
    """
    # Named `columns` rather than `dict` so the builtin is not shadowed.
    columns = {'Image file name': [],
               'Class name': []}
    return pd.DataFrame(columns)
def write_df(df, file, cls):
    """Return a copy of *df* with one extra row for *file* and its class.

    Args:
        df: Results table with 'Image file name' and 'Class name' columns.
        file: An uploaded file (any object with a ``.name`` attribute) or a
            plain file-name string.
        cls: Predicted class label to record for that file.

    Returns:
        A new DataFrame with a fresh 0..n-1 index; *df* is not modified.
    """
    # Deliberately not cached: the append is trivial, whereas st.cache_data
    # would have to hash the uploaded file argument on every call and keep
    # every intermediate frame in the cache.
    name = getattr(file, 'name', file)
    rec = {'Image file name': name,
           'Class name': cls}
    return pd.concat([df, pd.DataFrame([rec])], ignore_index=True)
@st.cache_data()
def convert_df(df):
    """Serialise the results table to CSV text without the index column."""
    csv_text = df.to_csv(encoding='utf-8', index=False)
    return csv_text
# Extensions accepted by the uploader. `type=` must be a list: the previous
# `"jpg" or 'jpeg' or ...` expression evaluated to just "jpg".
_ACCEPTED_TYPES = ["jpg", "jpeg", "bmp", "png", "tif", "tiff"]


@st.cache_resource()
def _load_example_images(n=5):
    """Return the first *n* images of the 'jacaranda' dataset's train split.

    Cached so the dataset is loaded once instead of on every Streamlit rerun
    (each button click reruns the whole script).
    """
    dataset = load_dataset('jacaranda', split='train')
    return dataset[0:n]['image']


def _rerun():
    """Trigger a script rerun using whichever rerun API Streamlit provides."""
    # `st.rerun` supersedes the deprecated `st.experimental_rerun`.
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()


def _predict_class(model, file, class_names):
    """Classify one uploaded image and return its class name."""
    # Rewind in case an earlier read (e.g. the thumbnail preview) moved the
    # stream position.
    file.seek(0)
    # Force 3 channels: PNG/TIFF can carry alpha and BMP can be palette or
    # greyscale, while ResNet50's preprocess_input expects RGB input.
    img = Image.open(file).convert("RGB")
    img_array = img_to_array(img)
    img_array = tf.expand_dims(img_array, axis=0)  # Create a batch of one
    processed_image = preprocess_input(img_array)
    predictions = model.predict(processed_image)
    return class_names[np.argmax(predictions[0])]


def run_the_app():
    """Show example clips, collect uploads, and classify them on request.

    Each uploaded image is classified with the cached model; per-image
    results are shown inline and offered as a downloadable CSV.
    """
    st.image(_load_example_images())

    class_names = ['Jacaranda', 'Others']
    model = load_model()
    df = generate_df()

    # Rotating the uploader's widget key yields a fresh, empty uploader; this
    # is what lets "Clear uploaded images" actually drop the selected files.
    if "uploader_key" not in st.session_state:
        st.session_state["uploader_key"] = 0
    uploaded_files = st.file_uploader(
        "Upload images",
        type=_ACCEPTED_TYPES,
        accept_multiple_files=True,
        key=f"uploader_{st.session_state['uploader_key']}")

    if uploaded_files:
        st.image(uploaded_files, width=100)

    if st.button("Clear uploaded images"):
        st.session_state["uploader_key"] += 1
        _rerun()

    if st.button("Generate prediction"):
        if not uploaded_files:
            st.warning("Please upload at least one image first.")
        else:
            for file in uploaded_files:
                cls = _predict_class(model, file, class_names)
                # Show the file's name, not the UploadedFile object's repr.
                st.markdown("Predicted class of the image {} is : {}".format(file.name, cls))
                df = write_df(df, file, cls)
            csv = convert_df(df)
            st.download_button("Download the results as CSV",
                               data=csv,
                               file_name="jacaranda_identification.csv")
# Script entry point: build the page only when this file is executed as the
# main script, not when it is imported.
if __name__ == "__main__":
    main()