Umar1623 commited on
Commit
0a32aab
·
verified ·
1 Parent(s): 92c551b

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +201 -0
app.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import streamlit as st
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt

from mlp_utils import (
    MLP, generate_dataset, split_data, train_model, plot_training_history,
    visualize_weights, plot_weight_optimization, visualize_network,
    plot_confusion_matrix, plot_classification_metrics, ACTIVATION_MAP
)

st.set_page_config(page_title="Interactive MLP Learning Platform", layout="wide")

st.title("Interactive MLP Learning Platform")
st.markdown("""
This application helps you learn about Multi-Layer Perceptrons (MLPs) through interactive experimentation.
You can generate synthetic data, design your own MLP architecture, and observe the training process.
""")

# --- Sidebar: dataset configuration -------------------------------------
st.sidebar.header("Dataset Configuration")
n_samples = st.sidebar.slider("Number of Samples", 100, 1000, 500)
n_features = st.sidebar.slider("Number of Features", 2, 10, 4)
n_classes = st.sidebar.slider("Number of Classes", 2, 5, 3)

# Data split percentages. Train % is derived, so it can go non-positive
# when the user pushes validation + test past 99%.
st.sidebar.subheader("Data Split (%)")
def_percent = 20
val_percent = st.sidebar.slider("Validation %", 0, 50, def_percent)
test_percent = st.sidebar.slider("Test %", 0, 50, def_percent)
train_percent = 100 - val_percent - test_percent
if train_percent < 1:
    st.sidebar.error("Train % must be at least 1%.")

# Generate the dataset and reset all downstream workflow flags.
# The `train_percent >= 1` guard is the actual enforcement of the split
# warning above: previously the button still generated a dataset with an
# empty training partition.
if st.sidebar.button("Generate Dataset") and train_percent >= 1:
    X, y = generate_dataset(n_samples, n_features, n_classes)
    (X_train, y_train), (X_val, y_val), (X_test, y_test) = split_data(
        X, y, val_percent / 100, test_percent / 100)
    st.session_state['X_train'] = X_train
    st.session_state['y_train'] = y_train
    st.session_state['X_val'] = X_val
    st.session_state['y_val'] = y_val
    st.session_state['X_test'] = X_test
    st.session_state['y_test'] = y_test
    st.session_state['dataset_generated'] = True
    # A new dataset invalidates any previously confirmed network,
    # training run, or test evaluation.
    st.session_state['network_confirmed'] = False
    st.session_state['training_complete'] = False
    st.session_state['testing_complete'] = False
50
+
51
# --- Main content area ---------------------------------------------------
if 'dataset_generated' in st.session_state:
    st.header("Dataset Information")
    st.write(f"Train: {len(st.session_state['X_train'])} samples | "
             f"Validation: {len(st.session_state['X_val'])} samples | "
             f"Test: {len(st.session_state['X_test'])} samples")

    # Derive dimensions from the stored data rather than the live sidebar
    # sliders: otherwise moving a slider after "Generate Dataset" would
    # desynchronize the preview columns / model dimensions / confusion
    # matrix from the dataset actually in session state.
    n_features = st.session_state['X_train'].shape[1]
    n_classes = len(np.unique(st.session_state['y_train']))

    # Dataset preview (head of the training partition).
    df = pd.DataFrame(st.session_state['X_train'],
                      columns=[f'Feature {i+1}' for i in range(n_features)])
    df['Class'] = st.session_state['y_train']
    st.subheader("Training Set Preview")
    st.dataframe(df.head())

    # --- MLP configuration widgets --------------------------------------
    st.header("MLP Configuration")
    n_hidden_layers = st.slider("Number of Hidden Layers", 1, 5, 2)
    hidden_sizes = []
    activations = []
    activation_options = list(ACTIVATION_MAP.keys())
    for i in range(n_hidden_layers):
        cols = st.columns([2, 2])
        with cols[0]:
            size = st.slider(f"Nodes in Hidden Layer {i+1}", 2, 20, 8, key=f"hsize_{i}")
        with cols[1]:
            # The last entry of ACTIVATION_MAP is excluded from the
            # hidden-layer choices (presumably the output activation —
            # TODO confirm against mlp_utils).
            act = st.selectbox(f"Activation for Layer {i+1}", activation_options[:-1], index=0, key=f"act_{i}")
        hidden_sizes.append(size)
        activations.append(act)
    # MLP expects one activation per weight layer; duplicate the first
    # choice for the input -> hidden-1 connection.
    activations = [activations[0]] + activations

    # Freeze the current architecture into session state.
    if st.button("Confirm Network"):
        st.session_state['hidden_sizes'] = hidden_sizes
        st.session_state['activations'] = activations
        st.session_state['network_confirmed'] = True
        st.session_state['training_complete'] = False
        st.session_state['testing_complete'] = False

    if st.session_state.get('network_confirmed', False):
        # Use the CONFIRMED architecture, not the live widget values, so
        # that editing the sliders after confirming does not silently
        # change what gets visualized or trained until the user confirms
        # again (previously `hidden_sizes` leaked the live values here).
        confirmed_sizes = st.session_state['hidden_sizes']
        confirmed_acts = st.session_state['activations']

        st.subheader("Network Architecture Visualization")
        fig = visualize_network(n_features, confirmed_sizes, n_classes)
        st.pyplot(fig)
        st.write(f"Input: {n_features} | Hidden: {confirmed_sizes} | Output: {n_classes}")
        st.write(f"Activations: {confirmed_acts}")

        # --- Training parameters -----------------------------------------
        st.subheader("Training Parameters")
        epochs = st.slider("Number of Epochs", 10, 200, 50)
        learning_rate = st.slider("Learning Rate", 0.001, 0.1, 0.01, 0.001)
        batch_size = st.slider("Batch Size", 8, 128, 32)

        # Train the model on the stored train/val partitions and cache
        # everything needed to re-render the results on later reruns.
        if st.button("Train MLP"):
            model = MLP(n_features, confirmed_sizes, n_classes, confirmed_acts)
            train_losses, train_accuracies, val_losses, val_accuracies, weights_history = train_model(
                model,
                st.session_state['X_train'],
                st.session_state['y_train'],
                st.session_state['X_val'],
                st.session_state['y_val'],
                epochs,
                learning_rate,
                batch_size,
                track_weights=True
            )
            st.session_state['model'] = model
            st.session_state['train_losses'] = train_losses
            st.session_state['train_accuracies'] = train_accuracies
            st.session_state['val_losses'] = val_losses
            st.session_state['val_accuracies'] = val_accuracies
            st.session_state['weights_history'] = weights_history
            st.session_state['training_complete'] = True
            st.session_state['testing_complete'] = False

        # --- Training results --------------------------------------------
        if st.session_state.get('training_complete', False):
            st.header("Training Results")
            fig = plot_training_history(
                st.session_state['train_losses'],
                st.session_state['train_accuracies'],
                st.session_state['val_losses'],
                st.session_state['val_accuracies']
            )
            st.pyplot(fig)

            st.subheader("Weight Visualization (All Layers)")
            weight_fig = visualize_weights(st.session_state['model'])
            st.pyplot(weight_fig)

            st.subheader("Weight Optimization (First Layer)")
            opt_fig = plot_weight_optimization(st.session_state['weights_history'])
            st.pyplot(opt_fig)

            col1, col2, col3, col4 = st.columns(4)
            with col1:
                st.metric("Final Training Loss", f"{st.session_state['train_losses'][-1]:.4f}")
            with col2:
                st.metric("Final Training Accuracy", f"{st.session_state['train_accuracies'][-1]:.2%}")
            with col3:
                st.metric("Final Validation Loss", f"{st.session_state['val_losses'][-1]:.4f}")
            with col4:
                st.metric("Final Validation Accuracy", f"{st.session_state['val_accuracies'][-1]:.2%}")

            # Evaluate the trained model on the held-out test partition.
            if st.button("Test on Unseen Data"):
                model = st.session_state['model']
                X_test = st.session_state['X_test']
                y_test = st.session_state['y_test']
                model.eval()
                with torch.no_grad():
                    X_tensor = torch.FloatTensor(X_test)
                    outputs = model(X_tensor)
                    _, predicted = torch.max(outputs.data, 1)
                    test_accuracy = (predicted.numpy() == y_test).mean()

                st.session_state['test_accuracy'] = test_accuracy
                st.session_state['test_predictions'] = predicted.numpy()
                st.session_state['testing_complete'] = True

            # --- Test results --------------------------------------------
            if st.session_state.get('testing_complete', False):
                st.header("Test Results")
                st.success(f"Test Accuracy: {st.session_state['test_accuracy']:.2%}")

                st.subheader("Confusion Matrix")
                cm_fig = plot_confusion_matrix(
                    st.session_state['y_test'],
                    st.session_state['test_predictions'],
                    n_classes
                )
                st.pyplot(cm_fig)

                st.subheader("Classification Metrics")
                metrics_df = plot_classification_metrics(
                    st.session_state['y_test'],
                    st.session_state['test_predictions'],
                    n_classes
                )
                st.dataframe(metrics_df)

                st.subheader("Additional Test Metrics")
                col1, col2 = st.columns(2)
                with col1:
                    st.metric("Test Accuracy", f"{st.session_state['test_accuracy']:.2%}")
                with col2:
                    st.metric("Test Error Rate", f"{1 - st.session_state['test_accuracy']:.2%}")
else:
    st.info("Please generate a dataset using the sidebar controls to begin.")