# KaburaJ's picture
# Update app.py
# 8fbfe88
# raw
# history blame contribute delete
# No virus
# 3.81 kB
import base64
import streamlit as st
from PIL import Image
import numpy as np
from keras.models import model_from_json
import subprocess
import os
import tensorflow as tf
from keras.applications.imagenet_utils import preprocess_input
# Page heading and sub-heading, emitted as raw HTML so the inline white text
# colour is applied.
for _heading_html in (
    '<h1 style="color:white;">Image Classification App</h1>',
    '<h2 style="color:white;">for classifying **zebras** and **horses**</h2>',
):
    st.markdown(_heading_html, unsafe_allow_html=True)
# The cache call must be applied as a decorator; as a bare statement it built
# a decorator and threw it away, so the file was re-read on every rerun.
@st.cache(allow_output_mutation=True)
def get_base64_of_bin_file(bin_file):
    """Return the contents of *bin_file* as a base64-encoded string.

    The result is cached by ``st.cache`` so repeated calls with the same
    path do not read and encode the file again.

    Args:
        bin_file: Path of the file to read; it is opened in binary mode.

    Returns:
        The base64 encoding of the file's bytes, decoded to ``str``.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    with open(bin_file, 'rb') as f:
        return base64.b64encode(f.read()).decode()
def set_png_as_page_bg(png_file):
    """Use the image stored at *png_file* as the app's page background.

    The file is embedded as a base64 ``data:`` URL inside a ``<style>``
    block targeting the ``.stApp`` selector, which is then injected with
    ``st.markdown``.

    Args:
        png_file: Path of the image file. Despite the parameter name, the
            media type is guessed from the file extension, falling back to
            ``image/png`` when it cannot be determined.

    Raises:
        OSError: If the image file cannot be read.
    """
    import mimetypes  # local import: only this function needs it

    # Label the data URL with the file's real media type instead of always
    # claiming PNG (the app passes a .webp file).
    mime_type, _ = mimetypes.guess_type(png_file)
    if not mime_type or not mime_type.startswith('image/'):
        mime_type = 'image/png'

    bin_str = get_base64_of_bin_file(png_file)
    # The original template carried a "# doesn't work" remark after the
    # background-attachment declaration. '#' does not start a comment in CSS,
    # so the remark is kept here rather than inside the stylesheet.
    page_bg_img = '''
<style>
.stApp {
background-image: url("data:%s;base64,%s");
background-size: cover;
background-repeat: no-repeat;
background-attachment: scroll;
}
</style>
''' % (mime_type, bin_str)
    st.markdown(page_bg_img, unsafe_allow_html=True)
# Install the page background when the script runs.
set_png_as_page_bg(png_file='background.webp')
# def load_model():
# # load json and create model
# json_file = open('model.json', 'r')
# loaded_model_json = json_file.read()
# json_file.close()
# CNN_class_index = model_from_json(loaded_model_json)
# # load weights into new model
# model = CNN_class_index.load_weights("model.h5")
# #model= tf.keras.load_model('model.h5')
# #CNN_class_index = json.load(open(f"{os.getcwd()}F:\Machine Learning Resources\ZebraHorse\model.json"))
# return model, CNN_class_index
def load_model():
    """Load the Keras model from ``model.h5``, downloading it first if absent.

    Returns:
        The model returned by ``tf.keras.models.load_model`` (loaded with
        ``compile=False``).

    Raises:
        subprocess.CalledProcessError: If ``curl`` exits with an error.
        OSError: If ``curl`` cannot be started or the file cannot be moved
            into place.
    """
    model_path = 'model.h5'
    if not os.path.isfile(model_path):
        # Use GitHub's "raw" endpoint: the "blob" URL serves an HTML page,
        # not the file itself.
        model_url = (
            'https://github.com/KaburaJ/Binary-Image-classification/'
            'raw/main/ZebraHorse/CNN%20Application/model.h5'
        )
        # Download to a temporary name so a failed or partial transfer is
        # never mistaken for a valid model by the isfile() check above.
        partial_path = model_path + '.part'
        try:
            # Argument list without a shell; --location follows the redirect
            # issued by the raw endpoint and --fail turns HTTP errors into a
            # non-zero exit status that check=True raises on.
            subprocess.run(
                ['curl', '--fail', '--location',
                 '--output', partial_path, model_url],
                check=True,
            )
        except (OSError, subprocess.CalledProcessError):
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise
        os.replace(partial_path, model_path)
    return tf.keras.models.load_model(model_path, compile=False)
# def load_model():
# # Load the model architecture
# with open('model.json', 'r') as f:
# model_from_json(f.read())
# # Load the model weights
# model.load_weights('model.h5')
# #CNN_class_index = json.load(open(f"{os.getcwd()}F:\Machine Learning Resources\ZebraHorse\model.json"))
# return model
def image_transformation(image):
    """Convert *image* to a NumPy array.

    Args:
        image: Any object ``np.array`` accepts; ``main`` passes a PIL image.

    Returns:
        A new ``numpy.ndarray`` built from *image*.
    """
    # The array used to be written to 'images.npy' and immediately loaded
    # back (with allow_pickle=True). That round trip returned an equal array
    # while adding disk I/O on a fixed file name, so it has been dropped.
    return np.array(image)
# def image_prediction(image, model):
# image = image_transformation(image=image)
# outputs = float(model.predict(image))
# _, y_hat = outputs.max(1)
# predicted_idx = str(y_hat.item())
# return predicted_idx
def _to_model_batch(image, model):
    """Turn a PIL image into a one-sample batch for ``model.predict``."""
    # NOTE(review): assumes the model takes 3-channel, channels-last input;
    # confirm against the training code.
    image = image.convert('RGB')
    input_shape = getattr(model, 'input_shape', None)
    if (
        isinstance(input_shape, tuple)
        and len(input_shape) == 4
        and all(isinstance(side, int) for side in input_shape[1:3])
    ):
        # PIL's resize takes (width, height); the shape is read as
        # (batch, height, width, channels).
        image = image.resize((input_shape[2], input_shape[1]))
    # NOTE(review): no pixel scaling is applied here because the training
    # preprocessing is not visible in this file; verify before relying on
    # the scores.
    return np.expand_dims(image_transformation(image=image), axis=0)


def _predicted_index(outputs):
    """Return the predicted class index for the first sample of *outputs*."""
    scores = np.asarray(outputs)[0].reshape(-1)
    if scores.size == 1:
        # NOTE(review): a single output unit is treated as a probability
        # thresholded at 0.5; confirm the model's final activation.
        return int(scores[0] >= 0.5)
    return int(np.argmax(scores))


def main():
    """Render the upload widget and classify the uploaded image on request.

    The uploaded image is shown in the left column. When the "Predict"
    button is pressed the model is loaded and the predicted class index and
    raw model scores are written to the right column.
    """
    image_file = st.file_uploader("Upload an image", type=['jpg', 'jpeg', 'png'])
    if not image_file:
        return

    left_column, right_column = st.columns(2)
    left_column.image(image_file, caption="Uploaded image", use_column_width=True)

    # Only pay for loading the model once a prediction is actually requested.
    if not st.button("Predict"):
        return

    try:
        model = load_model()
    except Exception as exc:  # UI boundary: report the failure, don't crash
        st.error("Error: Model could not be loaded")
        st.exception(exc)
        return
    if model is None:
        st.error("Error: Model could not be loaded")
        return

    # model.predict expects a batch of arrays. The previous int(...) call
    # raised on any real image, and the torch-style `outputs.max(1)` unpacking
    # does not work on the NumPy array Keras returns.
    batch = _to_model_batch(Image.open(image_file), model)
    outputs = np.asarray(model.predict(batch))
    predicted_idx = str(_predicted_index(outputs))

    right_column.title("Prediction")
    right_column.write(predicted_idx)
    # Show the raw scores instead of calling decode_predictions, which was
    # never imported and decodes ImageNet classes rather than this model's.
    # NOTE(review): the index -> 'Zebra'/'Horse' mapping is not visible in
    # this file, so no label name is displayed.
    right_column.write(outputs[0].tolist())
# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    main()