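# Streamlit app: Interactive MLP Learning Platform.
# (This source was captured from a hosting page whose header read
# "Spaces: Sleeping"; that header was page text, not code.)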
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import streamlit as st
import torch

from mlp_utils import (
    MLP, generate_dataset, split_data, train_model, plot_training_history,
    visualize_weights, plot_weight_optimization, visualize_network,
    plot_confusion_matrix, plot_classification_metrics, ACTIVATION_MAP,
)
# Page setup and introduction. set_page_config is kept ahead of every other
# st.* call, which Streamlit has historically required.
_APP_TITLE = "Interactive MLP Learning Platform"
_INTRO_MARKDOWN = """
This application helps you learn about Multi-Layer Perceptrons (MLPs) through interactive experimentation.
You can generate synthetic data, design your own MLP architecture, and observe the training process.
"""

st.set_page_config(layout="wide", page_title=_APP_TITLE)
st.title(_APP_TITLE)
st.markdown(_INTRO_MARKDOWN)
# Sidebar: dataset size/shape and how the samples are split.
# Everything created inside `with st.sidebar:` is rendered in the sidebar.
with st.sidebar:
    st.header("Dataset Configuration")
    n_samples = st.slider("Number of Samples", 100, 1000, 500)
    n_features = st.slider("Number of Features", 2, 10, 4)
    n_classes = st.slider("Number of Classes", 2, 5, 3)

    # Validation and test shares are chosen explicitly; training gets the rest.
    st.subheader("Data Split (%)")
    def_percent = 20
    val_percent = st.slider("Validation %", 0, 50, def_percent)
    test_percent = st.slider("Test %", 0, 50, def_percent)
    train_percent = 100 - (val_percent + test_percent)
    if not train_percent >= 1:
        st.error("Train % must be at least 1%.")
# Generate dataset.
# The button is disabled while validation + test leave no share for training,
# so split_data is never asked to produce an empty training split (the sidebar
# error message alone did not prevent that).
if st.sidebar.button("Generate Dataset", disabled=train_percent < 1):
    X, y = generate_dataset(n_samples, n_features, n_classes)
    (X_train, y_train), (X_val, y_val), (X_test, y_test) = split_data(
        X, y, val_percent / 100, test_percent / 100)
    new_state = {
        'X_train': X_train,
        'y_train': y_train,
        'X_val': X_val,
        'y_val': y_val,
        'X_test': X_test,
        'y_test': y_test,
        # Remember the sizes the data was generated with: the sidebar sliders
        # can be moved afterwards without regenerating the data.
        'data_n_features': n_features,
        'data_n_classes': n_classes,
        'dataset_generated': True,
        # A new dataset invalidates any confirmed network, trained model and
        # test results.
        'network_confirmed': False,
        'training_complete': False,
        'testing_complete': False,
    }
    for state_key, state_value in new_state.items():
        st.session_state[state_key] = state_value
# Main content area


def _show_figure(fig):
    """Render a matplotlib figure in the app, then close it.

    The figures are rebuilt on every Streamlit rerun; closing each one after
    rendering keeps pyplot from accumulating open figures across reruns.
    """
    st.pyplot(fig)
    plt.close(fig)


if 'dataset_generated' in st.session_state:
    # Size everything from the data that was actually generated, not from the
    # live sidebar sliders: those can change without regenerating the data,
    # which would mislabel the preview and build a model whose input/output
    # sizes do not match the stored arrays. Fall back to the data shape / slider
    # when the sizes were not recorded at generation time.
    data_n_features = st.session_state.get(
        'data_n_features', np.shape(st.session_state['X_train'])[1])
    data_n_classes = st.session_state.get('data_n_classes', n_classes)

    st.header("Dataset Information")
    st.write(f"Train: {len(st.session_state['X_train'])} samples | "
             f"Validation: {len(st.session_state['X_val'])} samples | "
             f"Test: {len(st.session_state['X_test'])} samples")

    # Display dataset statistics
    df = pd.DataFrame(st.session_state['X_train'],
                      columns=[f'Feature {i+1}' for i in range(data_n_features)])
    df['Class'] = st.session_state['y_train']
    st.subheader("Training Set Preview")
    st.dataframe(df.head())

    # MLP Configuration
    st.header("MLP Configuration")
    n_hidden_layers = st.slider("Number of Hidden Layers", 1, 5, 2)
    hidden_sizes = []
    activations = []
    activation_options = list(ACTIVATION_MAP.keys())
    for i in range(n_hidden_layers):
        cols = st.columns([2, 2])
        with cols[0]:
            size = st.slider(f"Nodes in Hidden Layer {i+1}", 2, 20, 8, key=f"hsize_{i}")
        with cols[1]:
            # The last ACTIVATION_MAP entry is excluded from hidden-layer
            # choices — presumably reserved for another use; confirm in mlp_utils.
            act = st.selectbox(f"Activation for Layer {i+1}", activation_options[:-1],
                               index=0, key=f"act_{i}")
        hidden_sizes.append(size)
        activations.append(act)
    # Add activation for input to first hidden: the first layer's choice is
    # duplicated so the list has n_hidden_layers + 1 entries (presumably what
    # MLP expects — confirm in mlp_utils).
    activations = [activations[0]] + activations

    # Confirm network button: snapshot the current choices.
    if st.button("Confirm Network"):
        st.session_state['hidden_sizes'] = hidden_sizes
        st.session_state['activations'] = activations
        st.session_state['network_confirmed'] = True
        st.session_state['training_complete'] = False
        st.session_state['testing_complete'] = False

    # Show network configuration
    if st.session_state.get('network_confirmed', False):
        # Always use the *confirmed* architecture. The layer widgets above may
        # have changed since "Confirm Network" was clicked; mixing live sizes
        # with confirmed activations can produce lists of different lengths.
        confirmed_sizes = st.session_state['hidden_sizes']
        confirmed_activations = st.session_state['activations']

        st.subheader("Network Architecture Visualization")
        _show_figure(visualize_network(data_n_features, confirmed_sizes, data_n_classes))
        st.write(f"Input: {data_n_features} | Hidden: {confirmed_sizes} | Output: {data_n_classes}")
        st.write(f"Activations: {confirmed_activations}")

        # Training parameters
        st.subheader("Training Parameters")
        epochs = st.slider("Number of Epochs", 10, 200, 50)
        learning_rate = st.slider("Learning Rate", 0.001, 0.1, 0.01, 0.001)
        batch_size = st.slider("Batch Size", 8, 128, 32)

        # Train button
        if st.button("Train MLP"):
            model = MLP(data_n_features, confirmed_sizes, data_n_classes, confirmed_activations)
            with st.spinner("Training..."):
                (train_losses, train_accuracies, val_losses, val_accuracies,
                 weights_history) = train_model(
                    model,
                    st.session_state['X_train'],
                    st.session_state['y_train'],
                    st.session_state['X_val'],
                    st.session_state['y_val'],
                    epochs,
                    learning_rate,
                    batch_size,
                    track_weights=True,
                )
            st.session_state['model'] = model
            st.session_state['train_losses'] = train_losses
            st.session_state['train_accuracies'] = train_accuracies
            st.session_state['val_losses'] = val_losses
            st.session_state['val_accuracies'] = val_accuracies
            st.session_state['weights_history'] = weights_history
            st.session_state['training_complete'] = True
            st.session_state['testing_complete'] = False

        # Show training results
        if st.session_state.get('training_complete', False):
            st.header("Training Results")
            _show_figure(plot_training_history(
                st.session_state['train_losses'],
                st.session_state['train_accuracies'],
                st.session_state['val_losses'],
                st.session_state['val_accuracies'],
            ))

            st.subheader("Weight Visualization (All Layers)")
            _show_figure(visualize_weights(st.session_state['model']))

            st.subheader("Weight Optimization (First Layer)")
            _show_figure(plot_weight_optimization(st.session_state['weights_history']))

            col1, col2, col3, col4 = st.columns(4)
            with col1:
                st.metric("Final Training Loss", f"{st.session_state['train_losses'][-1]:.4f}")
            with col2:
                st.metric("Final Training Accuracy", f"{st.session_state['train_accuracies'][-1]:.2%}")
            with col3:
                st.metric("Final Validation Loss", f"{st.session_state['val_losses'][-1]:.4f}")
            with col4:
                st.metric("Final Validation Accuracy", f"{st.session_state['val_accuracies'][-1]:.2%}")

            # Test button
            if st.button("Test on Unseen Data"):
                model = st.session_state['model']
                X_test = st.session_state['X_test']
                y_test = np.asarray(st.session_state['y_test'])
                if len(y_test) == 0:
                    # The mean over an empty comparison would be NaN.
                    st.warning("The test split is empty; increase Test % and regenerate the dataset.")
                else:
                    model.eval()
                    with torch.no_grad():
                        outputs = model(torch.FloatTensor(X_test))
                    # Predicted class = index of the largest output per row.
                    predicted = outputs.argmax(dim=1).numpy()
                    st.session_state['test_accuracy'] = (predicted == y_test).mean()
                    st.session_state['test_predictions'] = predicted
                    st.session_state['testing_complete'] = True

            if st.session_state.get('testing_complete', False):
                st.header("Test Results")
                st.success(f"Test Accuracy: {st.session_state['test_accuracy']:.2%}")

                # Confusion Matrix
                st.subheader("Confusion Matrix")
                _show_figure(plot_confusion_matrix(
                    st.session_state['y_test'],
                    st.session_state['test_predictions'],
                    data_n_classes,
                ))

                # Classification Metrics
                st.subheader("Classification Metrics")
                metrics_df = plot_classification_metrics(
                    st.session_state['y_test'],
                    st.session_state['test_predictions'],
                    data_n_classes,
                )
                st.dataframe(metrics_df)

                # Additional Test Metrics
                st.subheader("Additional Test Metrics")
                col1, col2 = st.columns(2)
                with col1:
                    st.metric("Test Accuracy", f"{st.session_state['test_accuracy']:.2%}")
                with col2:
                    st.metric("Test Error Rate", f"{1 - st.session_state['test_accuracy']:.2%}")
else:
    st.info("Please generate a dataset using the sidebar controls to begin.")