Hem345 commited on
Commit
d32a5af
1 Parent(s): c89ca4b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -36
app.py CHANGED
@@ -34,6 +34,18 @@ def create_model():
34
  model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
35
  return model
36
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  # Streamlit UI
38
  st.title("CNN for MNIST Classification")
39
 
@@ -42,8 +54,12 @@ model_path = "mnist_cnn_model.h5"
42
 
43
  if st.button("Train Model"):
44
  model = create_model()
 
 
 
 
45
  with st.spinner("Training..."):
46
- history = model.fit(train_images, train_labels, validation_data=(test_images, test_labels), epochs=10, batch_size=64)
47
 
48
  # Save the model
49
  model.save(model_path)
@@ -61,38 +77,4 @@ if st.button("Train Model"):
61
  ax2.plot(history.history["accuracy"], label="Train Accuracy")
62
  ax2.plot(history.history["val_accuracy"], label="Val Accuracy")
63
  ax2.set_title("Training and Validation Accuracy")
64
- ax2.set_xlabel("Epoch")
65
- ax2.set_ylabel("Accuracy")
66
- ax2.legend()
67
-
68
- st.pyplot(fig)
69
-
70
- # Evaluate the model on test data
71
- test_preds = np.argmax(model.predict(test_images), axis=1)
72
- true_labels = np.argmax(test_labels, axis=1)
73
-
74
- # Classification report
75
- report = classification_report(true_labels, test_preds, digits=4)
76
- st.text("Classification Report:")
77
- st.text(report)
78
-
79
- # Testing with a specific index
80
- index = st.number_input("Enter an index (0-9999) to test:", min_value=0, max_value=9999, step=1)
81
-
82
- def test_index_prediction(index):
83
- image = test_images[index].reshape(28, 28)
84
- st.image(image, caption=f"True Label: {true_labels[index]}", use_column_width=True)
85
-
86
- # Reload the model if needed
87
- if not os.path.exists(model_path):
88
- st.error("Train the model first.")
89
- return
90
-
91
- model = keras.models.load_model(model_path)
92
-
93
- prediction = model.predict(test_images[index].reshape(1, 28, 28, 1))
94
- predicted_class = np.argmax(prediction)
95
- st.write(f"Predicted Class: {predicted_class}")
96
-
97
- if st.button("Test Index"):
98
- test_index_prediction(index)
 
34
  model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
35
  return model
36
 
37
+ # Custom callback for logging
38
+ class StreamlitLogger(keras.callbacks.Callback):
39
+ def on_epoch_end(self, epoch, logs=None):
40
+ if logs is None:
41
+ logs = {}
42
+
43
+ st.write(f"Epoch {epoch + 1}:")
44
+ st.write(f" Train Loss: {logs.get('loss'):.4f}")
45
+ st.write(f" Train Accuracy: {logs.get('accuracy'):.4f}")
46
+ st.write(f" Val Loss: {logs.get('val_loss'):.4f}")
47
+ st.write(f" Val Accuracy: {logs.get('val_accuracy'):.4f}")
48
+
49
  # Streamlit UI
50
  st.title("CNN for MNIST Classification")
51
 
 
54
 
55
  if st.button("Train Model"):
56
  model = create_model()
57
+
58
+ # Create logger instance
59
+ logger = StreamlitLogger()
60
+
61
  with st.spinner("Training..."):
62
+ history = model.fit(train_images, train_labels, validation_data=(test_images, test_labels), epochs=10, batch_size=64, callbacks=[logger])
63
 
64
  # Save the model
65
  model.save(model_path)
 
77
  ax2.plot(history.history["accuracy"], label="Train Accuracy")
78
  ax2.plot(history.history["val_accuracy"], label="Val Accuracy")
79
  ax2.set_title("Training and Validation Accuracy")
80
+ ax2.set_xlabel