MLP-visualizer / app.py
Umar1623's picture
Upload app.py
0a32aab verified
import streamlit as st
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
from mlp_utils import (
MLP, generate_dataset, split_data, train_model, plot_training_history,
visualize_weights, plot_weight_optimization, visualize_network,
plot_confusion_matrix, plot_classification_metrics, ACTIVATION_MAP
)
# Page chrome. The title is shared by the browser tab and the on-page heading,
# so it is defined once; the intro text is rendered as Markdown below it.
APP_TITLE = "Interactive MLP Learning Platform"
APP_INTRO = """
This application helps you learn about Multi-Layer Perceptrons (MLPs) through interactive experimentation.
You can generate synthetic data, design your own MLP architecture, and observe the training process.
"""

# Page configuration is issued before any other Streamlit command, matching
# the original statement order.
st.set_page_config(page_title=APP_TITLE, layout="wide")
st.title(APP_TITLE)
st.markdown(APP_INTRO)
# Sidebar for dataset configuration.
# Everything created inside this `with` block is rendered in the sidebar.
with st.sidebar:
    st.header("Dataset Configuration")
    n_samples = st.slider("Number of Samples", min_value=100, max_value=1000, value=500)
    n_features = st.slider("Number of Features", min_value=2, max_value=10, value=4)
    n_classes = st.slider("Number of Classes", min_value=2, max_value=5, value=3)

    # Data split percentages: validation and test are chosen by the user,
    # the training share is whatever remains out of 100.
    st.subheader("Data Split (%)")
    def_percent = 20
    val_percent = st.slider("Validation %", min_value=0, max_value=50, value=def_percent)
    test_percent = st.slider("Test %", min_value=0, max_value=50, value=def_percent)
    train_percent = 100 - (val_percent + test_percent)

    # Both sliders at their maximum leave nothing to train on; tell the user.
    if train_percent < 1:
        st.error("Train % must be at least 1%.")
# Generate dataset
if st.sidebar.button("Generate Dataset"):
    if train_percent < 1:
        # The split-percentage check only displays a message; the click
        # handler itself must refuse to build a dataset that leaves no
        # training data.
        st.sidebar.warning(
            "Dataset not generated: lower Validation % or Test % so that "
            "training data remains."
        )
    else:
        X, y = generate_dataset(n_samples, n_features, n_classes)
        (X_train, y_train), (X_val, y_val), (X_test, y_test) = split_data(
            X, y, val_percent / 100, test_percent / 100)
        st.session_state['X_train'] = X_train
        st.session_state['y_train'] = y_train
        st.session_state['X_val'] = X_val
        st.session_state['y_val'] = y_val
        st.session_state['X_test'] = X_test
        st.session_state['y_test'] = y_test
        # Remember the dimensions this dataset was built with. The sidebar
        # sliders can be moved later without regenerating, so everything
        # downstream must use these values rather than the live sliders.
        st.session_state['n_features'] = n_features
        st.session_state['n_classes'] = n_classes
        st.session_state['dataset_generated'] = True
        # A new dataset invalidates any previously confirmed/trained/tested state.
        st.session_state['network_confirmed'] = False
        st.session_state['training_complete'] = False
        st.session_state['testing_complete'] = False

# Main content area
if 'dataset_generated' in st.session_state:
    # Dimensions of the stored dataset. The slider values are only a fallback
    # for session state that predates these keys.
    data_n_features = st.session_state.get('n_features', n_features)
    data_n_classes = st.session_state.get('n_classes', n_classes)

    st.header("Dataset Information")
    st.write(f"Train: {len(st.session_state['X_train'])} samples | "
             f"Validation: {len(st.session_state['X_val'])} samples | "
             f"Test: {len(st.session_state['X_test'])} samples")

    # Display dataset statistics
    df = pd.DataFrame(
        st.session_state['X_train'],
        columns=[f'Feature {i+1}' for i in range(data_n_features)],
    )
    df['Class'] = st.session_state['y_train']
    st.subheader("Training Set Preview")
    st.dataframe(df.head())

    # MLP Configuration
    st.header("MLP Configuration")
    n_hidden_layers = st.slider("Number of Hidden Layers", 1, 5, 2)
    hidden_sizes = []
    activations = []
    activation_options = list(ACTIVATION_MAP.keys())
    for i in range(n_hidden_layers):
        cols = st.columns([2, 2])
        with cols[0]:
            size = st.slider(f"Nodes in Hidden Layer {i+1}", 2, 20, 8, key=f"hsize_{i}")
        with cols[1]:
            # The last ACTIVATION_MAP entry is deliberately not offered here.
            act = st.selectbox(f"Activation for Layer {i+1}", activation_options[:-1],
                               index=0, key=f"act_{i}")
        hidden_sizes.append(size)
        activations.append(act)

    # Add activation for input to first hidden
    # NOTE(review): this yields one more activation than hidden layers;
    # presumably MLP expects that layout - confirm against mlp_utils.MLP.
    activations = [activations[0]] + activations

    # Confirm network button
    if st.button("Confirm Network"):
        st.session_state['hidden_sizes'] = hidden_sizes
        st.session_state['activations'] = activations
        st.session_state['network_confirmed'] = True
        # A new architecture invalidates earlier training/testing results.
        st.session_state['training_complete'] = False
        st.session_state['testing_complete'] = False

    # Show network configuration
    if st.session_state.get('network_confirmed', False):
        # Use the confirmed architecture everywhere below, so layer sizes and
        # activations always come from the same confirmation even if the
        # configuration widgets are changed afterwards.
        confirmed_sizes = st.session_state['hidden_sizes']
        confirmed_activations = st.session_state['activations']

        st.subheader("Network Architecture Visualization")
        fig = visualize_network(data_n_features, confirmed_sizes, data_n_classes)
        st.pyplot(fig)
        plt.close(fig)  # release the figure; the script reruns on every interaction
        st.write(f"Input: {data_n_features} | Hidden: {confirmed_sizes} | Output: {data_n_classes}")
        st.write(f"Activations: {confirmed_activations}")

        # Training parameters
        st.subheader("Training Parameters")
        epochs = st.slider("Number of Epochs", 10, 200, 50)
        learning_rate = st.slider("Learning Rate", 0.001, 0.1, 0.01, 0.001)
        batch_size = st.slider("Batch Size", 8, 128, 32)

        # Train button
        if st.button("Train MLP"):
            model = MLP(data_n_features, confirmed_sizes, data_n_classes, confirmed_activations)
            train_losses, train_accuracies, val_losses, val_accuracies, weights_history = train_model(
                model,
                st.session_state['X_train'],
                st.session_state['y_train'],
                st.session_state['X_val'],
                st.session_state['y_val'],
                epochs,
                learning_rate,
                batch_size,
                track_weights=True
            )
            st.session_state['model'] = model
            st.session_state['train_losses'] = train_losses
            st.session_state['train_accuracies'] = train_accuracies
            st.session_state['val_losses'] = val_losses
            st.session_state['val_accuracies'] = val_accuracies
            st.session_state['weights_history'] = weights_history
            st.session_state['training_complete'] = True
            # Results from a previously trained model no longer apply.
            st.session_state['testing_complete'] = False

        # Show training results
        if st.session_state.get('training_complete', False):
            st.header("Training Results")
            fig = plot_training_history(
                st.session_state['train_losses'],
                st.session_state['train_accuracies'],
                st.session_state['val_losses'],
                st.session_state['val_accuracies']
            )
            st.pyplot(fig)
            plt.close(fig)

            st.subheader("Weight Visualization (All Layers)")
            weight_fig = visualize_weights(st.session_state['model'])
            st.pyplot(weight_fig)
            plt.close(weight_fig)

            st.subheader("Weight Optimization (First Layer)")
            opt_fig = plot_weight_optimization(st.session_state['weights_history'])
            st.pyplot(opt_fig)
            plt.close(opt_fig)

            col1, col2, col3, col4 = st.columns(4)
            with col1:
                st.metric("Final Training Loss", f"{st.session_state['train_losses'][-1]:.4f}")
            with col2:
                st.metric("Final Training Accuracy", f"{st.session_state['train_accuracies'][-1]:.2%}")
            with col3:
                st.metric("Final Validation Loss", f"{st.session_state['val_losses'][-1]:.4f}")
            with col4:
                st.metric("Final Validation Accuracy", f"{st.session_state['val_accuracies'][-1]:.2%}")

            # Test button
            if st.button("Test on Unseen Data"):
                X_test = st.session_state['X_test']
                y_test = st.session_state['y_test']
                if len(X_test) == 0:
                    # Test % can be set to 0; there is nothing to evaluate then.
                    st.warning("The test split is empty. Regenerate the dataset with Test % above 0.")
                else:
                    model = st.session_state['model']
                    model.eval()
                    with torch.no_grad():
                        X_tensor = torch.FloatTensor(X_test)
                        outputs = model(X_tensor)
                        # Predicted class = index of the largest output per row.
                        _, predicted = torch.max(outputs, 1)
                    predictions = predicted.numpy()
                    st.session_state['test_accuracy'] = (predictions == y_test).mean()
                    st.session_state['test_predictions'] = predictions
                    st.session_state['testing_complete'] = True

            if st.session_state.get('testing_complete', False):
                st.header("Test Results")
                st.success(f"Test Accuracy: {st.session_state['test_accuracy']:.2%}")

                # Confusion Matrix
                st.subheader("Confusion Matrix")
                cm_fig = plot_confusion_matrix(
                    st.session_state['y_test'],
                    st.session_state['test_predictions'],
                    data_n_classes
                )
                st.pyplot(cm_fig)
                plt.close(cm_fig)

                # Classification Metrics
                st.subheader("Classification Metrics")
                metrics_df = plot_classification_metrics(
                    st.session_state['y_test'],
                    st.session_state['test_predictions'],
                    data_n_classes
                )
                st.dataframe(metrics_df)

                # Additional Test Metrics
                st.subheader("Additional Test Metrics")
                col1, col2 = st.columns(2)
                with col1:
                    st.metric("Test Accuracy", f"{st.session_state['test_accuracy']:.2%}")
                with col2:
                    st.metric("Test Error Rate", f"{1 - st.session_state['test_accuracy']:.2%}")
else:
    st.info("Please generate a dataset using the sidebar controls to begin.")