Hem345 commited on
Commit
6f3b8fe
1 Parent(s): d32a5af

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -11
app.py CHANGED
@@ -34,28 +34,27 @@ def create_model():
34
  model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
35
  return model
36
 
 
 
 
 
 
 
37
  # Custom callback for logging
38
  class StreamlitLogger(keras.callbacks.Callback):
39
  def on_epoch_end(self, epoch, logs=None):
40
- if logs is None:
41
  logs = {}
42
-
43
  st.write(f"Epoch {epoch + 1}:")
44
  st.write(f" Train Loss: {logs.get('loss'):.4f}")
45
  st.write(f" Train Accuracy: {logs.get('accuracy'):.4f}")
46
  st.write(f" Val Loss: {logs.get('val_loss'):.4f}")
47
  st.write(f" Val Accuracy: {logs.get('val_accuracy'):.4f}")
48
 
49
- # Streamlit UI
50
- st.title("CNN for MNIST Classification")
51
-
52
- # Check if model is saved
53
- model_path = "mnist_cnn_model.h5"
54
-
55
  if st.button("Train Model"):
56
  model = create_model()
57
 
58
- # Create logger instance
59
  logger = StreamlitLogger()
60
 
61
  with st.spinner("Training..."):
@@ -72,9 +71,44 @@ if st.button("Train Model"):
72
  ax1.set_title("Training and Validation Loss")
73
  ax1.set_xlabel("Epoch")
74
  ax1.set_ylabel("Loss")
75
- ax1.legend()
76
 
77
  ax2.plot(history.history["accuracy"], label="Train Accuracy")
78
  ax2.plot(history.history["val_accuracy"], label="Val Accuracy")
79
  ax2.set_title("Training and Validation Accuracy")
80
- ax2.set_xlabel
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
35
  return model
36
 
37
+ # Streamlit UI
38
+ st.title("CNN for MNIST Classification")
39
+
40
+ # Check if model is saved
41
+ model_path = "mnist_cnn_model.h5"
42
+
43
  # Custom callback for logging
44
  class StreamlitLogger(keras.callbacks.Callback):
45
  def on_epoch_end(self, epoch, logs=None):
46
+ if logs is none:
47
  logs = {}
48
+
49
  st.write(f"Epoch {epoch + 1}:")
50
  st.write(f" Train Loss: {logs.get('loss'):.4f}")
51
  st.write(f" Train Accuracy: {logs.get('accuracy'):.4f}")
52
  st.write(f" Val Loss: {logs.get('val_loss'):.4f}")
53
  st.write(f" Val Accuracy: {logs.get('val_accuracy'):.4f}")
54
 
 
 
 
 
 
 
55
  if st.button("Train Model"):
56
  model = create_model()
57
 
 
58
  logger = StreamlitLogger()
59
 
60
  with st.spinner("Training..."):
 
71
  ax1.set_title("Training and Validation Loss")
72
  ax1.set_xlabel("Epoch")
73
  ax1.set_ylabel("Loss")
 
74
 
75
  ax2.plot(history.history["accuracy"], label="Train Accuracy")
76
  ax2.plot(history.history["val_accuracy"], label="Val Accuracy")
77
  ax2.set_title("Training and Validation Accuracy")
78
+ ax2.set_xlabel("Epoch")
79
+ ax2.set_ylabel("Accuracy")
80
+
81
+ ax1.legend()
82
+ ax2.legend()
83
+
84
+ st.pyplot(fig)
85
+
86
+ # Evaluate the model on test data
87
+ test_preds = np.argmax(model.predict(test_images), axis=1)
88
+ true_labels = np.argmax(test_labels, axis=1)
89
+
90
+ st.session_state['true_labels'] = true_labels
91
+
92
+ report = classification_report(true_labels, test_preds, digits=4)
93
+ st.text("Classification Report:")
94
+ st.text(report)
95
+
96
+ index = st.number_input("Enter an index (0-9999) to test:", min_value=0, max_value=9999, step=1)
97
+
98
+ def test_index_prediction(index):
99
+ image = test_images[index].reshape(28, 28)
100
+ st.image(image, caption=f"True Label: {st.session_state['true_labels'][index]}", use_column_width=True)
101
+
102
+ # Reload the model
103
+ if not os.path.exists(model_path):
104
+ st.error("Train the model first.")
105
+ return
106
+
107
+ model = keras.models.load_model(model_path)
108
+
109
+ prediction = model.predict(test_images[index].reshape(1, 28, 28, 1))
110
+ predicted_class = np.argmax(prediction)
111
+ st.write(f"Predicted Class: {predicted_class}")
112
+
113
+ if st.button("Test Index"):
114
+ test_index_prediction(index)