"""Streamlit front end for experimenting with Multi-Layer Perceptrons.

The page walks through four stages, each gated by a flag in
``st.session_state`` so results survive Streamlit's rerun-on-interaction
model:

1. generate and split a synthetic dataset (sidebar),
2. configure and confirm the hidden layers,
3. train the network and inspect loss/accuracy/weight plots,
4. evaluate the trained model on the held-out test split.

Dataset generation, the model class, training and all plotting are
delegated to the project-local ``mlp_utils`` module.
"""
import streamlit as st
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt

from mlp_utils import (
    MLP,
    generate_dataset,
    split_data,
    train_model,
    plot_training_history,
    visualize_weights,
    plot_weight_optimization,
    visualize_network,
    plot_confusion_matrix,
    plot_classification_metrics,
    ACTIVATION_MAP,
)


def _show_figure(fig):
    """Render a Matplotlib figure in the page, then close it.

    Streamlit reruns the whole script on every interaction; without the
    explicit close, each rerun would leave another set of figures
    registered with pyplot.
    """
    st.pyplot(fig)
    plt.close(fig)


st.set_page_config(page_title="Interactive MLP Learning Platform", layout="wide")

st.title("Interactive MLP Learning Platform")
st.markdown(
    " This application helps you learn about Multi-Layer Perceptrons (MLPs) "
    "through interactive experimentation. You can generate synthetic data, "
    "design your own MLP architecture, and observe the training process. "
)

# Sidebar for dataset configuration
st.sidebar.header("Dataset Configuration")
n_samples = st.sidebar.slider("Number of Samples", 100, 1000, 500)
n_features = st.sidebar.slider("Number of Features", 2, 10, 4)
n_classes = st.sidebar.slider("Number of Classes", 2, 5, 3)

# Data split percentages
st.sidebar.subheader("Data Split (%)")
def_percent = 20
val_percent = st.sidebar.slider("Validation %", 0, 50, def_percent)
test_percent = st.sidebar.slider("Test %", 0, 50, def_percent)
train_percent = 100 - val_percent - test_percent
split_is_valid = train_percent >= 1
if not split_is_valid:
    st.sidebar.error("Train % must be at least 1%.")

# Generate dataset. The button is always rendered, but a click is ignored
# while the split leaves no training data.
if st.sidebar.button("Generate Dataset") and split_is_valid:
    X, y = generate_dataset(n_samples, n_features, n_classes)
    (X_train, y_train), (X_val, y_val), (X_test, y_test) = split_data(
        X, y, val_percent / 100, test_percent / 100
    )
    st.session_state['X_train'] = X_train
    st.session_state['y_train'] = y_train
    st.session_state['X_val'] = X_val
    st.session_state['y_val'] = y_val
    st.session_state['X_test'] = X_test
    st.session_state['y_test'] = y_test
    # Pin the dimensions the data was generated with: the sidebar sliders
    # may be moved afterwards, and everything downstream must keep matching
    # the stored arrays rather than the live slider values.
    st.session_state['n_features'] = n_features
    st.session_state['n_classes'] = n_classes
    st.session_state['dataset_generated'] = True
    # A new dataset invalidates every later stage.
    st.session_state['network_confirmed'] = False
    st.session_state['training_complete'] = False
    st.session_state['testing_complete'] = False

# Main content area
if st.session_state.get('dataset_generated', False):
    # Dimensions of the stored dataset (not of the current sliders).
    data_n_features = st.session_state.get('n_features', n_features)
    data_n_classes = st.session_state.get('n_classes', n_classes)

    st.header("Dataset Information")
    st.write(
        f"Train: {len(st.session_state['X_train'])} samples | "
        f"Validation: {len(st.session_state['X_val'])} samples | "
        f"Test: {len(st.session_state['X_test'])} samples"
    )

    # Display dataset statistics
    df = pd.DataFrame(
        st.session_state['X_train'],
        columns=[f'Feature {i+1}' for i in range(data_n_features)],
    )
    df['Class'] = st.session_state['y_train']
    st.subheader("Training Set Preview")
    st.dataframe(df.head())

    # MLP Configuration
    st.header("MLP Configuration")
    n_hidden_layers = st.slider("Number of Hidden Layers", 1, 5, 2)
    hidden_sizes = []
    activations = []
    activation_options = list(ACTIVATION_MAP.keys())
    for i in range(n_hidden_layers):
        cols = st.columns([2, 2])
        with cols[0]:
            size = st.slider(
                f"Nodes in Hidden Layer {i+1}", 2, 20, 8, key=f"hsize_{i}"
            )
        with cols[1]:
            # The last ACTIVATION_MAP entry is deliberately not offered for
            # hidden layers (presumably an output-only activation -- confirm
            # against mlp_utils).
            act = st.selectbox(
                f"Activation for Layer {i+1}",
                activation_options[:-1],
                index=0,
                key=f"act_{i}",
            )
        hidden_sizes.append(size)
        activations.append(act)

    # Add activation for input to first hidden
    activations = [activations[0]] + activations

    # Confirm network button: snapshot the current widget values so later
    # slider changes do not silently alter the confirmed architecture.
    if st.button("Confirm Network"):
        st.session_state['hidden_sizes'] = hidden_sizes
        st.session_state['activations'] = activations
        st.session_state['network_confirmed'] = True
        st.session_state['training_complete'] = False
        st.session_state['testing_complete'] = False

    # Show network configuration
    if st.session_state.get('network_confirmed', False):
        # Use the confirmed snapshot so sizes and activations always have
        # matching lengths, even if the sliders above were changed since.
        confirmed_sizes = st.session_state['hidden_sizes']
        confirmed_activations = st.session_state['activations']

        st.subheader("Network Architecture Visualization")
        _show_figure(
            visualize_network(data_n_features, confirmed_sizes, data_n_classes)
        )
        st.write(
            f"Input: {data_n_features} | Hidden: {confirmed_sizes} | "
            f"Output: {data_n_classes}"
        )
        st.write(f"Activations: {confirmed_activations}")

        # Training parameters
        st.subheader("Training Parameters")
        epochs = st.slider("Number of Epochs", 10, 200, 50)
        # Three decimals so the 0.001 step is actually visible on the slider.
        learning_rate = st.slider(
            "Learning Rate", 0.001, 0.1, 0.01, 0.001, format="%.3f"
        )
        batch_size = st.slider("Batch Size", 8, 128, 32)

        # Train button
        if st.button("Train MLP"):
            model = MLP(
                data_n_features,
                confirmed_sizes,
                data_n_classes,
                confirmed_activations,
            )
            (
                train_losses,
                train_accuracies,
                val_losses,
                val_accuracies,
                weights_history,
            ) = train_model(
                model,
                st.session_state['X_train'],
                st.session_state['y_train'],
                st.session_state['X_val'],
                st.session_state['y_val'],
                epochs,
                learning_rate,
                batch_size,
                track_weights=True,
            )
            st.session_state['model'] = model
            st.session_state['train_losses'] = train_losses
            st.session_state['train_accuracies'] = train_accuracies
            st.session_state['val_losses'] = val_losses
            st.session_state['val_accuracies'] = val_accuracies
            st.session_state['weights_history'] = weights_history
            st.session_state['training_complete'] = True
            # Earlier test results belong to the previous model.
            st.session_state['testing_complete'] = False

        # Show training results
        if st.session_state.get('training_complete', False):
            st.header("Training Results")
            _show_figure(
                plot_training_history(
                    st.session_state['train_losses'],
                    st.session_state['train_accuracies'],
                    st.session_state['val_losses'],
                    st.session_state['val_accuracies'],
                )
            )

            st.subheader("Weight Visualization (All Layers)")
            _show_figure(visualize_weights(st.session_state['model']))

            st.subheader("Weight Optimization (First Layer)")
            _show_figure(
                plot_weight_optimization(st.session_state['weights_history'])
            )

            col1, col2, col3, col4 = st.columns(4)
            with col1:
                st.metric(
                    "Final Training Loss",
                    f"{st.session_state['train_losses'][-1]:.4f}",
                )
            with col2:
                st.metric(
                    "Final Training Accuracy",
                    f"{st.session_state['train_accuracies'][-1]:.2%}",
                )
            with col3:
                st.metric(
                    "Final Validation Loss",
                    f"{st.session_state['val_losses'][-1]:.4f}",
                )
            with col4:
                st.metric(
                    "Final Validation Accuracy",
                    f"{st.session_state['val_accuracies'][-1]:.2%}",
                )

            # Test button
            if st.button("Test on Unseen Data"):
                model = st.session_state['model']
                X_test = st.session_state['X_test']
                y_test = st.session_state['y_test']
                if len(X_test) == 0:
                    # A 0% test split would make the accuracy the mean of an
                    # empty array (NaN); report the problem instead.
                    st.warning(
                        "The test split is empty; choose a Test % above 0 "
                        "and regenerate the dataset."
                    )
                else:
                    model.eval()
                    with torch.no_grad():
                        X_tensor = torch.as_tensor(X_test, dtype=torch.float32)
                        outputs = model(X_tensor)
                        # Predicted class = index of the largest output.
                        predicted = outputs.argmax(dim=1)
                    predictions = predicted.numpy()
                    test_accuracy = (predictions == y_test).mean()
                    st.session_state['test_accuracy'] = test_accuracy
                    st.session_state['test_predictions'] = predictions
                    st.session_state['testing_complete'] = True

            if st.session_state.get('testing_complete', False):
                st.header("Test Results")
                st.success(
                    f"Test Accuracy: {st.session_state['test_accuracy']:.2%}"
                )

                # Confusion Matrix
                st.subheader("Confusion Matrix")
                _show_figure(
                    plot_confusion_matrix(
                        st.session_state['y_test'],
                        st.session_state['test_predictions'],
                        data_n_classes,
                    )
                )

                # Classification Metrics
                st.subheader("Classification Metrics")
                metrics_df = plot_classification_metrics(
                    st.session_state['y_test'],
                    st.session_state['test_predictions'],
                    data_n_classes,
                )
                st.dataframe(metrics_df)

                # Additional Test Metrics
                st.subheader("Additional Test Metrics")
                col1, col2 = st.columns(2)
                with col1:
                    st.metric(
                        "Test Accuracy",
                        f"{st.session_state['test_accuracy']:.2%}",
                    )
                with col2:
                    st.metric(
                        "Test Error Rate",
                        f"{1 - st.session_state['test_accuracy']:.2%}",
                    )
else:
    st.info("Please generate a dataset using the sidebar controls to begin.")