import tensorflow as tf
import matplotlib.pyplot as plt
from PIL import Image, ImageOps
from tensorflow.keras.utils import img_to_array
from streamlit_drawable_canvas import st_canvas
import streamlit as st

# Side length (pixels) of the square grayscale images fed to the model.
MODEL_INPUT_SIZE = 28
# Side length (pixels) of the drawing canvas and of the upscaled preview image.
CANVAS_SIZE = 224


def _load_model():
    """Load the pre-trained Keras model stored in ``mnist.h5``."""
    return tf.keras.models.load_model('mnist.h5')


# Streamlit re-executes this whole script on every widget interaction (every
# stroke when realtime updates are on), so an uncached load would re-read the
# .h5 file each time. Cache it when the installed Streamlit offers
# st.cache_resource; otherwise fall back to loading on every run.
_get_model = (
    st.cache_resource(_load_model)
    if hasattr(st, 'cache_resource')
    else _load_model
)


def _to_model_input(rgba_array):
    """Convert the canvas pixel array into a model-ready batch.

    Args:
        rgba_array: The canvas ``image_data`` array, interpreted as RGBA.

    Returns:
        A ``(preview, batch)`` tuple. ``preview`` is the grayscale PIL image
        resized to MODEL_INPUT_SIZE x MODEL_INPUT_SIZE. ``batch`` is a float32
        array of shape (1, MODEL_INPUT_SIZE, MODEL_INPUT_SIZE, 1) with pixel
        values scaled from [0, 255] to [0, 1].
    """
    preview = ImageOps.grayscale(
        Image.fromarray(rgba_array.astype('uint8'), mode="RGBA")
    ).resize((MODEL_INPUT_SIZE, MODEL_INPUT_SIZE))
    batch = img_to_array(preview) / 255
    batch = batch.reshape(1, MODEL_INPUT_SIZE, MODEL_INPUT_SIZE, 1)
    return preview, batch.astype('float32')


def _plot_probabilities(probabilities):
    """Return a new bar-chart figure of per-digit probabilities.

    Uses the object-oriented Matplotlib API so the caller can close exactly
    this figure once it has been rendered.
    """
    digits = range(len(probabilities))
    fig, ax = plt.subplots(figsize=(12, 3))
    ax.bar(digits, probabilities)
    ax.set_xticks(digits)
    ax.set_xlabel('Digit')
    ax.set_ylabel('Probability')
    ax.set_title('Drawing Prediction')
    ax.set_ylim(0, 1)
    return fig


st.write('# MNIST Digit Recognition')
st.write('## Using a CNN `Keras` model')

# Pre-trained model (cached across reruns when supported; see _get_model).
model = _get_model()

plt.rcParams.update({'font.size': 18})

# Sidebar settings for the drawing canvas.
stroke_width = st.sidebar.slider("Stroke width: ", 1, 25, 9)
realtime_update = st.sidebar.checkbox("Update in realtime", True)

# White strokes on a black background. fill_color only applies to filled
# shapes, not to the 'freedraw' mode used here.
canvas_result = st_canvas(
    fill_color="rgba(255, 165, 0, 0.3)",
    stroke_width=stroke_width,
    stroke_color='#FFFFFF',
    background_color='#000000',
    update_streamlit=realtime_update,
    height=CANVAS_SIZE,
    width=CANVAS_SIZE,
    drawing_mode='freedraw',
    key="canvas",
)

if canvas_result.image_data is not None:
    st.write('### Resized Image')
    st.write("The image needs to be resized, because it can only input 28x28 images")
    preview, batch = _to_model_input(canvas_result.image_data)
    # Show the low-resolution model input scaled back up to canvas size.
    st.image(preview, width=CANVAS_SIZE)

    st.write('### Predicted Digit')
    # getbbox() is None when every pixel is 0, i.e. nothing has been drawn;
    # a prediction on an empty canvas would be meaningless.
    if preview.getbbox() is None:
        st.info('Draw a digit on the canvas to see a prediction.')
    else:
        # verbose=0 keeps Keras from logging a progress bar on every rerun.
        probabilities = model.predict(batch, verbose=0)[0]
        fig = _plot_probabilities(probabilities)
        st.pyplot(fig)
        # Release the figure; pyplot otherwise keeps every figure alive,
        # leaking one per rerun.
        plt.close(fig)