Sathwikchowdary commited on
Commit
aec340e
ยท
verified ยท
1 Parent(s): 3b6db41

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +135 -0
app.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import numpy as np
4
+ import seaborn as sns
5
+ import matplotlib.pyplot as plt
6
+ from sklearn.preprocessing import StandardScaler
7
+ from sklearn.model_selection import train_test_split
8
+ from keras.models import Sequential
9
+ from keras.layers import InputLayer, Dense
10
+ from sklearn.datasets import make_circles, make_classification, make_moons, make_blobs
11
+ from mlxtend.plotting import plot_decision_regions
12
+ from keras.optimizers import SGD
13
+ import time
14
+
15
+ # Custom background and styling
16
+ st.markdown(
17
+ """
18
+ <style>
19
+ .main {
20
+ background: linear-gradient(to right, #f0f4f8, #d9e2ec);
21
+ }
22
+ </style>
23
+ """,
24
+ unsafe_allow_html=True
25
+ )
26
+
27
+ # App Title
28
+ st.title("๐Ÿง  NeuroVision Lab - Interactive Neural Network Playground")
29
+
30
+ # Sidebar: Dataset selection
31
+ st.sidebar.header("๐ŸŽฒ Generate Synthetic Data")
32
+ data_type = st.sidebar.selectbox("Select Dataset Type", ["make_circles", "make_classification", "make_moons", "make_blobs"])
33
+ factor = st.sidebar.slider("Circle Factor (for make_circles)", 0.1, 1.0, 0.2)
34
+ noise = st.sidebar.slider("Add Noise", 0.0, 1.0, 0.1)
35
+ samples = st.sidebar.slider("Total Samples", 1000, 10000, 10000, step=100)
36
+
37
+ generate_scatter = st.sidebar.button("๐Ÿ”„ Create Dataset")
38
+
39
+ # Initialize session state
40
+ if 'X' not in st.session_state:
41
+ st.session_state['X'] = None
42
+ if 'y' not in st.session_state:
43
+ st.session_state['y'] = None
44
+
45
+ # Function to generate data
46
+ def generate_data(data_type, samples, noise, factor):
47
+ if data_type == "make_circles":
48
+ st.session_state['X'], st.session_state['y'] = make_circles(n_samples=samples, noise=noise, factor=factor, random_state=42)
49
+ elif data_type == "make_classification":
50
+ st.session_state['X'], st.session_state['y'] = make_classification(n_samples=samples, n_features=2, n_informative=2,
51
+ n_redundant=0, n_clusters_per_class=1, flip_y=noise, random_state=42)
52
+ elif data_type == "make_moons":
53
+ st.session_state['X'], st.session_state['y'] = make_moons(n_samples=samples, noise=noise, random_state=42)
54
+ elif data_type == "make_blobs":
55
+ st.session_state['X'], st.session_state['y'] = make_blobs(n_samples=samples, centers=2, cluster_std=1.0, random_state=42)
56
+
57
+ # Scatterplot of generated data
58
+ if generate_scatter:
59
+ generate_data(data_type, samples, noise, factor)
60
+ if st.session_state['X'] is not None and st.session_state['y'] is not None:
61
+ df = pd.DataFrame(st.session_state['X'], columns=["x1", "x2"])
62
+ df["label"] = st.session_state['y']
63
+
64
+ st.subheader(f"๐Ÿงฉ Visualizing: {data_type}")
65
+ fig1, ax1 = plt.subplots()
66
+ sns.scatterplot(data=df, x="x1", y="x2", hue="label", palette="viridis", ax=ax1)
67
+ st.pyplot(fig1)
68
+ else:
69
+ st.warning("Data generation unsuccessful. Please check your parameters.")
70
+
71
+ # Sidebar: Training Configuration
72
+ st.sidebar.header("โš™๏ธ Model Configuration")
73
+ test_percent = st.sidebar.slider("Test Set (%)", 10, 90, 20)
74
+ test_size = test_percent / 100
75
+ learning_rate = st.sidebar.selectbox("Choose Learning Rate", [0.0001, 0.001, 0.01, 0.1])
76
+ act_fun = st.sidebar.selectbox("Activation Function", ["sigmoid", "tanh", "relu"])
77
+ batch_size = st.sidebar.slider("Batch Size", 1, 10000, 6400)
78
+ epochs = st.sidebar.slider("Training Epochs", 1, 1000, 600)
79
+
80
+ # Train Model and Plot Decision Surface
81
+ if st.sidebar.button("๐Ÿงฎ Train Model & Show Decision Surface"):
82
+ if st.session_state['X'] is None or st.session_state['y'] is None:
83
+ st.error("โš ๏ธ Please generate a dataset first.")
84
+ else:
85
+ # Preprocessing
86
+ x_train, x_test, y_train, y_test = train_test_split(st.session_state['X'], st.session_state['y'], test_size=test_size, stratify=st.session_state['y'], random_state=1)
87
+ scaler = StandardScaler()
88
+ x_train = scaler.fit_transform(x_train)
89
+ x_test = scaler.transform(x_test)
90
+
91
+ # Build model
92
+ model = Sequential()
93
+ model.add(InputLayer(input_shape=(2,)))
94
+ for units in [4, 2, 2]:
95
+ model.add(Dense(units, activation=act_fun))
96
+ model.add(Dense(1, activation="sigmoid"))
97
+
98
+ sgd = SGD(learning_rate=learning_rate)
99
+ model.compile(optimizer=sgd, loss="binary_crossentropy", metrics=["accuracy"])
100
+
101
+ # Show training progress
102
+ st.subheader("๐Ÿ” Model Training Progress")
103
+ progress_bar = st.progress(0)
104
+ progress_pct = st.empty()
105
+
106
+ history = model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, verbose=0, validation_split=0.2)
107
+ for epoch in range(epochs):
108
+ progress = int((epoch + 1) / epochs * 100)
109
+ progress_bar.progress(progress)
110
+ progress_pct.write(f"{progress}%")
111
+ time.sleep(0.01)
112
+
113
+ # Decision surface visualization
114
+ st.subheader("๐Ÿง  Neural Network Decision Boundary")
115
+ fig2, ax2 = plt.subplots()
116
+ plot_decision_regions(x_train, y_train, clf=model, legend=2, ax=ax2)
117
+ st.pyplot(fig2)
118
+
119
+ st.session_state['history'] = history
120
+
121
+ # Show Loss Curve
122
+ if st.sidebar.button("๐Ÿ“‰ Display Loss Curve"):
123
+ if 'history' in st.session_state:
124
+ st.subheader("๐Ÿ“‰ Training vs Validation Loss")
125
+ history = st.session_state['history']
126
+ fig3, ax3 = plt.subplots()
127
+ ax3.plot(history.history['loss'], label='Train Loss')
128
+ ax3.plot(history.history['val_loss'], label='Val Loss')
129
+ ax3.set_xlabel("Epochs")
130
+ ax3.set_ylabel("Loss")
131
+ ax3.set_title("Loss Progress Over Time")
132
+ ax3.legend()
133
+ st.pyplot(fig3)
134
+ else:
135
+ st.warning("โณ Train the model to visualize the loss curve.")