eaglelandsonce commited on
Commit
2f37879
1 Parent(s): a7a0009

Create 7_mnist.py

Browse files
Files changed (1) hide show
  1. pages/7_mnist.py +121 -0
pages/7_mnist.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import tensorflow as tf
3
+ from tensorflow import keras
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+
7
+ def load_and_preprocess_mnist():
8
+ (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
9
+
10
+ x_train = x_train.astype('float32') / 255.0
11
+ x_test = x_test.astype('float32') / 255.0
12
+
13
+ x_train = x_train.reshape((-1, 28, 28, 1))
14
+ x_test = x_test.reshape((-1, 28, 28, 1))
15
+
16
+ y_train = keras.utils.to_categorical(y_train, 10)
17
+ y_test = keras.utils.to_categorical(y_test, 10)
18
+
19
+ return (x_train, y_train), (x_test, y_test)
20
+
21
+ def create_mnist_model():
22
+ model = keras.Sequential([
23
+ keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)),
24
+ keras.layers.MaxPooling2D(pool_size=(2, 2)),
25
+ keras.layers.Conv2D(64, kernel_size=(3, 3), activation='relu'),
26
+ keras.layers.MaxPooling2D(pool_size=(2, 2)),
27
+ keras.layers.Flatten(),
28
+ keras.layers.Dropout(0.5),
29
+ keras.layers.Dense(64, activation='relu'),
30
+ keras.layers.Dense(10, activation='softmax')
31
+ ])
32
+
33
+ model.compile(optimizer='adam',
34
+ loss='categorical_crossentropy',
35
+ metrics=['accuracy'])
36
+
37
+ return model
38
+
39
+ def train_model(model, x_train, y_train, epochs, batch_size):
40
+ history = model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1)
41
+ return history
42
+
43
+ def plot_training_history(history):
44
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
45
+
46
+ ax1.plot(history.history['accuracy'], label='Training Accuracy')
47
+ ax1.plot(history.history['val_accuracy'], label='Validation Accuracy')
48
+ ax1.set_title('Model Accuracy')
49
+ ax1.set_xlabel('Epoch')
50
+ ax1.set_ylabel('Accuracy')
51
+ ax1.legend()
52
+
53
+ ax2.plot(history.history['loss'], label='Training Loss')
54
+ ax2.plot(history.history['val_loss'], label='Validation Loss')
55
+ ax2.set_title('Model Loss')
56
+ ax2.set_xlabel('Epoch')
57
+ ax2.set_ylabel('Loss')
58
+ ax2.legend()
59
+
60
+ return fig
61
+
62
+ def main():
63
+ st.title("MNIST Digit Classification with Keras and Streamlit")
64
+
65
+ # Load and preprocess data
66
+ (x_train, y_train), (x_test, y_test) = load_and_preprocess_mnist()
67
+
68
+ # Create model
69
+ if 'model' not in st.session_state:
70
+ st.session_state.model = create_mnist_model()
71
+
72
+ # Sidebar for training parameters
73
+ st.sidebar.header("Training Parameters")
74
+ epochs = st.sidebar.slider("Number of Epochs", min_value=1, max_value=50, value=10)
75
+ batch_size = st.sidebar.selectbox("Batch Size", options=[32, 64, 128, 256], index=2)
76
+
77
+ # Train model button
78
+ if st.sidebar.button("Train Model"):
79
+ with st.spinner("Training in progress..."):
80
+ history = train_model(st.session_state.model, x_train, y_train, epochs, batch_size)
81
+ st.success("Training completed!")
82
+
83
+ # Plot training history
84
+ st.subheader("Training History")
85
+ fig = plot_training_history(history)
86
+ st.pyplot(fig)
87
+
88
+ # Evaluate model
89
+ test_loss, test_acc = st.session_state.model.evaluate(x_test, y_test)
90
+ st.write(f"Test accuracy: {test_acc:.4f}")
91
+
92
+ # Set a flag to indicate the model has been trained
93
+ st.session_state.model_trained = True
94
+
95
+ # Test on random image
96
+ st.subheader("Test on Random Image")
97
+ if st.button("Select Random Image"):
98
+ if not hasattr(st.session_state, 'model_trained'):
99
+ st.warning("Please train the model first before testing.")
100
+ else:
101
+ # Select a random image from the test set
102
+ idx = np.random.randint(0, x_test.shape[0])
103
+ image = x_test[idx]
104
+ true_label = np.argmax(y_test[idx])
105
+
106
+ # Make prediction
107
+ prediction = st.session_state.model.predict(image[np.newaxis, ...])[0]
108
+ predicted_label = np.argmax(prediction)
109
+
110
+ # Display image and prediction
111
+ fig, ax = plt.subplots()
112
+ ax.imshow(image.reshape(28, 28), cmap='gray')
113
+ ax.axis('off')
114
+ st.pyplot(fig)
115
+
116
+ st.write(f"True Label: {true_label}")
117
+ st.write(f"Predicted Label: {predicted_label}")
118
+ st.write(f"Confidence: {prediction[predicted_label]:.4f}")
119
+
120
+ if __name__ == "__main__":
121
+ main()