Spaces:
Build error
Build error
| import streamlit as st | |
| import numpy as np | |
| from PIL import Image | |
| from tensorflow.keras.models import load_model | |
| # Install the streamlit_drawable_canvas package if you haven't already | |
| # !pip install streamlit_drawable_canvas | |
| # Import the st_canvas function | |
| from streamlit_drawable_canvas import st_canvas | |
| # Function to preprocess the drawn image | |
def preprocess_image(drawing, size=(28, 28)):
    """Turn a raw canvas drawing into a normalized single-channel image array.

    The drawing is cast to ``uint8``, wrapped in a PIL image, resized to
    *size* and converted to 8-bit grayscale ("L" mode).

    Args:
        drawing: Pixel array of the drawing; presumably the RGBA
            ``image_data`` array returned by ``st_canvas`` (TODO confirm).
        size: Target ``(width, height)`` passed to ``Image.resize``.

    Returns:
        Float array of shape ``(height, width, 1)`` with values in [0, 1].
    """
    gray = Image.fromarray(np.uint8(drawing)).resize(size).convert('L')
    # Scale 0-255 intensities into the [0, 1] range.
    normalized = np.array(gray) / 255.0
    # Append a trailing channel axis: one channel, since the image is grayscale.
    return normalized[..., np.newaxis]
@st.cache_resource
def _load_mnist_model(model_path):
    """Load the Keras model stored at *model_path*, cached across reruns.

    ``st.cache_resource`` keeps the loaded model alive between Streamlit
    reruns, so the file is read from disk once rather than on every click.
    """
    return load_model(model_path)


def preprocess_and_predict(image, model_path="mnist_cnn_model.h5"):
    """Classify a preprocessed digit image and return the predicted class.

    Args:
        image: Array holding 28 * 28 values, e.g. the ``(28, 28, 1)`` output
            of ``preprocess_image``. It is reshaped into the
            ``(1, 28, 28, 1)`` single-image batch fed to the model.
        model_path: Path of the saved Keras model to use. Defaults to
            ``mnist_cnn_model.h5``, the file this app has always loaded.

    Returns:
        Index of the highest-scoring model output, as a plain ``int``.
    """
    model = _load_mnist_model(model_path)
    # One reshape both adds the batch axis and fixes the layout; the former
    # expand_dims that preceded it was redundant (same element order).
    batch = np.reshape(image, (1, 28, 28, 1))
    # verbose=0 suppresses the per-call progress bar in the server log.
    prediction = model.predict(batch, verbose=0)
    return int(np.argmax(prediction))
| # Main code | |
def main():
    """Render the digit-drawing page and show the model's prediction on request.

    The user draws white strokes on a 168x168 px black canvas (6x the 28x28
    model input). Clicking "Predict digit" preprocesses the drawing,
    classifies it, and displays the predicted class. If there is no drawing
    yet, a warning is shown instead.
    """
    st.title('Draw Digit')
    # Create a drawing canvas: white strokes on a black background.
    drawing = st_canvas(
        fill_color="rgb(0, 0, 0)",  # Fill color for shapes (not used by freedraw strokes)
        # The canvas is downscaled 6x (168 -> 28) before prediction. A 12 px
        # stroke ends up ~2 px wide, roughly MNIST-like, whereas the former
        # 4 px stroke shrank below one pixel and became faint.
        stroke_width=12,
        stroke_color="rgb(255, 255, 255)",  # Stroke color
        background_color="#000000",  # Background color of the canvas component
        height=168,  # Height of the canvas
        width=168,  # Width of the canvas
        drawing_mode="freedraw",  # Drawing mode: "freedraw" or "transform"
        key="canvas",
    )

    if not st.button('Predict digit'):
        return

    image_data = drawing.image_data
    # The canvas component may not have returned any image data yet.
    if image_data is None:
        st.warning("Please draw a digit before predicting.")
        return

    processed_image = preprocess_image(image_data)
    # Background pixels are black (0) after grayscale conversion, so an
    # all-zero image means nothing has been drawn.
    if not processed_image.any():
        st.warning("Please draw a digit before predicting.")
        return

    digit_class = preprocess_and_predict(processed_image)
    st.title("Predicted Digit:")
    st.success(digit_class)
# Entry point: render the app when this file is executed as a script
# (e.g. via `streamlit run`).
if __name__ == "__main__":
    main()