Spaces:
Running
Running
eaglelandsonce
commited on
Commit
•
2f37879
1
Parent(s):
a7a0009
Create 7_mnist.py
Browse files- pages/7_mnist.py +121 -0
pages/7_mnist.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import tensorflow as tf
|
3 |
+
from tensorflow import keras
|
4 |
+
import numpy as np
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
|
7 |
+
def load_and_preprocess_mnist():
|
8 |
+
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
|
9 |
+
|
10 |
+
x_train = x_train.astype('float32') / 255.0
|
11 |
+
x_test = x_test.astype('float32') / 255.0
|
12 |
+
|
13 |
+
x_train = x_train.reshape((-1, 28, 28, 1))
|
14 |
+
x_test = x_test.reshape((-1, 28, 28, 1))
|
15 |
+
|
16 |
+
y_train = keras.utils.to_categorical(y_train, 10)
|
17 |
+
y_test = keras.utils.to_categorical(y_test, 10)
|
18 |
+
|
19 |
+
return (x_train, y_train), (x_test, y_test)
|
20 |
+
|
21 |
+
def create_mnist_model():
|
22 |
+
model = keras.Sequential([
|
23 |
+
keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)),
|
24 |
+
keras.layers.MaxPooling2D(pool_size=(2, 2)),
|
25 |
+
keras.layers.Conv2D(64, kernel_size=(3, 3), activation='relu'),
|
26 |
+
keras.layers.MaxPooling2D(pool_size=(2, 2)),
|
27 |
+
keras.layers.Flatten(),
|
28 |
+
keras.layers.Dropout(0.5),
|
29 |
+
keras.layers.Dense(64, activation='relu'),
|
30 |
+
keras.layers.Dense(10, activation='softmax')
|
31 |
+
])
|
32 |
+
|
33 |
+
model.compile(optimizer='adam',
|
34 |
+
loss='categorical_crossentropy',
|
35 |
+
metrics=['accuracy'])
|
36 |
+
|
37 |
+
return model
|
38 |
+
|
39 |
+
def train_model(model, x_train, y_train, epochs, batch_size):
|
40 |
+
history = model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1)
|
41 |
+
return history
|
42 |
+
|
43 |
+
def plot_training_history(history):
|
44 |
+
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
|
45 |
+
|
46 |
+
ax1.plot(history.history['accuracy'], label='Training Accuracy')
|
47 |
+
ax1.plot(history.history['val_accuracy'], label='Validation Accuracy')
|
48 |
+
ax1.set_title('Model Accuracy')
|
49 |
+
ax1.set_xlabel('Epoch')
|
50 |
+
ax1.set_ylabel('Accuracy')
|
51 |
+
ax1.legend()
|
52 |
+
|
53 |
+
ax2.plot(history.history['loss'], label='Training Loss')
|
54 |
+
ax2.plot(history.history['val_loss'], label='Validation Loss')
|
55 |
+
ax2.set_title('Model Loss')
|
56 |
+
ax2.set_xlabel('Epoch')
|
57 |
+
ax2.set_ylabel('Loss')
|
58 |
+
ax2.legend()
|
59 |
+
|
60 |
+
return fig
|
61 |
+
|
62 |
+
def main():
|
63 |
+
st.title("MNIST Digit Classification with Keras and Streamlit")
|
64 |
+
|
65 |
+
# Load and preprocess data
|
66 |
+
(x_train, y_train), (x_test, y_test) = load_and_preprocess_mnist()
|
67 |
+
|
68 |
+
# Create model
|
69 |
+
if 'model' not in st.session_state:
|
70 |
+
st.session_state.model = create_mnist_model()
|
71 |
+
|
72 |
+
# Sidebar for training parameters
|
73 |
+
st.sidebar.header("Training Parameters")
|
74 |
+
epochs = st.sidebar.slider("Number of Epochs", min_value=1, max_value=50, value=10)
|
75 |
+
batch_size = st.sidebar.selectbox("Batch Size", options=[32, 64, 128, 256], index=2)
|
76 |
+
|
77 |
+
# Train model button
|
78 |
+
if st.sidebar.button("Train Model"):
|
79 |
+
with st.spinner("Training in progress..."):
|
80 |
+
history = train_model(st.session_state.model, x_train, y_train, epochs, batch_size)
|
81 |
+
st.success("Training completed!")
|
82 |
+
|
83 |
+
# Plot training history
|
84 |
+
st.subheader("Training History")
|
85 |
+
fig = plot_training_history(history)
|
86 |
+
st.pyplot(fig)
|
87 |
+
|
88 |
+
# Evaluate model
|
89 |
+
test_loss, test_acc = st.session_state.model.evaluate(x_test, y_test)
|
90 |
+
st.write(f"Test accuracy: {test_acc:.4f}")
|
91 |
+
|
92 |
+
# Set a flag to indicate the model has been trained
|
93 |
+
st.session_state.model_trained = True
|
94 |
+
|
95 |
+
# Test on random image
|
96 |
+
st.subheader("Test on Random Image")
|
97 |
+
if st.button("Select Random Image"):
|
98 |
+
if not hasattr(st.session_state, 'model_trained'):
|
99 |
+
st.warning("Please train the model first before testing.")
|
100 |
+
else:
|
101 |
+
# Select a random image from the test set
|
102 |
+
idx = np.random.randint(0, x_test.shape[0])
|
103 |
+
image = x_test[idx]
|
104 |
+
true_label = np.argmax(y_test[idx])
|
105 |
+
|
106 |
+
# Make prediction
|
107 |
+
prediction = st.session_state.model.predict(image[np.newaxis, ...])[0]
|
108 |
+
predicted_label = np.argmax(prediction)
|
109 |
+
|
110 |
+
# Display image and prediction
|
111 |
+
fig, ax = plt.subplots()
|
112 |
+
ax.imshow(image.reshape(28, 28), cmap='gray')
|
113 |
+
ax.axis('off')
|
114 |
+
st.pyplot(fig)
|
115 |
+
|
116 |
+
st.write(f"True Label: {true_label}")
|
117 |
+
st.write(f"Predicted Label: {predicted_label}")
|
118 |
+
st.write(f"Confidence: {prediction[predicted_label]:.4f}")
|
119 |
+
|
120 |
+
if __name__ == "__main__":
|
121 |
+
main()
|