File size: 5,749 Bytes
f8ea024 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
# Alzheimer's Prediction App with Random Forest Classifier
# -----------------------------------------------------------
# Made with ❤️ for the contest, featuring long code, animations, and brain 🧠 emojis.
# Importing Libraries
import streamlit as st
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
accuracy_score,
classification_report,
confusion_matrix,
roc_curve,
auc,
)
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import seaborn as sns
from plotly import graph_objs as go
import time
# Set up the Streamlit app
# The title and icon use the brain emoji. The previous literals held a
# mis-decoded byte sequence ("π§ ") that is neither an emoji nor a shortcode.
st.set_page_config(
    page_title="🧠 Alzheimer's Detection",
    page_icon="🧠",
    layout="wide",
)

# Add loading animation
# NOTE(review): the leading "π" in the spinner text looks like the residue of
# another mis-decoded emoji whose code point cannot be recovered from this
# file; confirm the intended glyph.
with st.spinner("π App is loading... Please wait!"):
    time.sleep(2)  # Simulating loading time
# Title and Description
st.title("🧠 Alzheimer's Disease Prediction")

# Introductory text followed by a horizontal rule. The blank line before
# "---" matters: without it Markdown reads the dashes as a setext-heading
# underline and renders the paragraph above as a heading.
st.markdown(
    """
Welcome to the **Alzheimer's Disease Prediction App**! This tool uses a **Random Forest Classifier**
to predict whether a patient has Alzheimer's disease based on clinical data.

---
"""
)
# Sidebar for uploading data
st.sidebar.header("Upload Dataset")
uploaded_file = st.sidebar.file_uploader(
    "Upload your CSV file containing the dataset", type=["csv"]
)

# Everything below needs `df`, so halt this script run until a file is provided.
if uploaded_file is None:
    st.sidebar.warning("⚠️ Please upload a dataset to proceed!")
    st.stop()

# A malformed, empty, or wrongly encoded upload is reported in the sidebar
# instead of surfacing as an unhandled exception.
try:
    df = pd.read_csv(uploaded_file)
except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as exc:
    st.sidebar.error(f"Could not read the uploaded CSV file: {exc}")
    st.stop()

# A header-only file parses fine but leaves nothing to split or train on.
if df.empty:
    st.sidebar.error("The uploaded CSV file contains no data rows.")
    st.stop()

st.sidebar.success("✅ Dataset loaded successfully!")
# Show the first rows of the uploaded data under its own heading
st.write("### Dataset Overview")
preview_rows = df.head()
st.dataframe(preview_rows)
# Preprocessing the data: drop duplicate rows, report missing values, then
# fill numeric gaps with the column mean.
st.write("### Data Preprocessing")
# NOTE(review): the leading "π" in the spinner text looks like the residue of
# a mis-decoded emoji whose code point cannot be recovered; confirm the glyph.
with st.spinner("π Preprocessing data..."):
    time.sleep(1)  # Simulate processing delay

    # Dropping duplicates
    initial_rows = df.shape[0]
    df = df.drop_duplicates()
    final_rows = df.shape[0]
    st.write(f"🗑️ Removed {initial_rows - final_rows} duplicate rows.")

    # Checking for missing values (only columns that actually have gaps are listed)
    missing_values = df.isnull().sum()
    st.write("#### Missing Values:")
    st.write(missing_values[missing_values > 0])

    # Fill missing values with the per-column mean.
    # numeric_only=True is required: on pandas >= 2.0 a plain df.mean() raises
    # TypeError as soon as the frame contains a text column. Non-numeric
    # columns have no mean and are left as they are.
    df = df.fillna(df.mean(numeric_only=True))
# Splitting features and target
target = st.sidebar.selectbox("Select the Target Column", df.columns)
y = df[target]

# StandardScaler needs numeric input, so only numeric/bool feature columns are
# kept. Text or date columns would otherwise abort the run with a conversion
# error; they are reported to the user and left out of the model.
candidate_features = df.drop(columns=[target])
X = candidate_features.select_dtypes(include=["number", "bool"])
ignored_columns = [col for col in candidate_features.columns if col not in X.columns]
if ignored_columns:
    st.warning(
        "Ignoring non-numeric feature columns: "
        + ", ".join(str(col) for col in ignored_columns)
    )
if X.shape[1] == 0:
    st.error("No numeric feature columns are available to train on.")
    st.stop()

# Splitting data into training and testing sets
# The slider works in whole percent; train_test_split expects a fraction.
test_size = st.sidebar.slider("Test Data Size (%)", min_value=10, max_value=50, value=20)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=test_size / 100, random_state=42
)

# Scaling the data
# The scaler is fitted on the training split only and then applied to the
# test split, so no test-set statistics leak into training.
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)
# Feature importance: fit a random forest on the scaled training split and
# chart the importance score it assigns to each feature column.
st.write("#### Feature Importance Visualization")
# NOTE(review): the leading "π" in the spinner text looks like the residue of
# a mis-decoded emoji whose code point cannot be recovered; confirm the glyph.
with st.spinner("π Generating feature importances..."):
    rf = RandomForestClassifier(random_state=42)
    rf.fit(X_train_scaled, y_train)
    feature_importances = rf.feature_importances_

    # Plotting feature importances (one horizontal bar per column of X)
    fig, ax = plt.subplots(figsize=(10, 6))
    sns.barplot(x=feature_importances, y=X.columns, ax=ax)
    ax.set_title("Feature Importance")
    ax.set_xlabel("Importance Score")
    ax.set_ylabel("Features")
    st.pyplot(fig)
    # Streamlit re-executes the whole script on every interaction; pyplot keeps
    # each figure alive until it is closed, so release it once rendered.
    plt.close(fig)
# Training the Random Forest Classifier
# `rf` fitted here is the model used by the evaluation and save sections below.
# NOTE(review): a classifier with the same settings is already fitted on the
# same data in the feature-importance section, so this second fit looks
# redundant; confirm before removing it.
st.write("### Training the Model")
with st.spinner("🧠 Training the Random Forest Classifier..."):
    time.sleep(2)
    rf = RandomForestClassifier(random_state=42)
    rf.fit(X_train_scaled, y_train)
# NOTE(review): the leading "π" below looks like the residue of a mis-decoded
# emoji whose code point cannot be recovered; confirm the intended glyph.
st.success("π Model trained successfully!")
# Model Evaluation: score the trained classifier on the held-out test split
st.write("### Model Evaluation")

# Predictions are kept in `y_pred` because the later sections reuse them
y_pred = rf.predict(X_test_scaled)
accuracy = accuracy_score(y_test, y_pred)
accuracy_line = f"**Accuracy:** {accuracy * 100:.2f}%"
st.write(accuracy_line)

# The report is a pre-formatted text table, so it is shown with st.text
st.write("#### Classification Report")
report_text = classification_report(y_test, y_pred)
st.text(report_text)
# Confusion Matrix
st.write("#### Confusion Matrix")
conf_matrix = confusion_matrix(y_test, y_pred)

# confusion_matrix orders rows/columns by the sorted union of the labels seen
# in y_test and y_pred; the same ordering is rebuilt here so the heatmap ticks
# show the actual class labels instead of positional indices.
class_labels = np.union1d(y_test, y_pred)

fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(
    conf_matrix,
    annot=True,
    fmt="d",
    cmap="Blues",
    xticklabels=class_labels,
    yticklabels=class_labels,
    ax=ax,
)
ax.set_title("Confusion Matrix")
ax.set_xlabel("Predicted")
ax.set_ylabel("Actual")
st.pyplot(fig)
# Release the figure: the script reruns on every interaction and pyplot keeps
# unclosed figures alive.
plt.close(fig)
# ROC Curve
st.write("#### ROC Curve")
# roc_curve only handles binary targets. With any other number of classes it
# raises (or the [:, 1] column does not exist), so the plot is skipped with an
# explanation instead of crashing the app.
if len(rf.classes_) == 2:
    # Column 1 of predict_proba corresponds to rf.classes_[1]. Passing that
    # label as pos_label keeps the curve correct for labels other than
    # {0, 1} / {-1, 1}, e.g. text labels.
    positive_class = rf.classes_[1]
    y_prob = rf.predict_proba(X_test_scaled)[:, 1]
    fpr, tpr, thresholds = roc_curve(y_test, y_prob, pos_label=positive_class)
    roc_auc = auc(fpr, tpr)

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(fpr, tpr, label=f"AUC = {roc_auc:.2f}", color="darkorange")
    ax.plot([0, 1], [0, 1], "r--")  # diagonal reference line
    ax.set_title("Receiver Operating Characteristic (ROC) Curve")
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.legend(loc="lower right")
    st.pyplot(fig)
    # Release the figure so reruns do not accumulate open figures.
    plt.close(fig)
else:
    st.info(
        "The ROC curve is only drawn for binary targets; "
        f"the selected target has {len(rf.classes_)} class(es)."
    )
# Additional Graphs: pairplot, correlation heatmap and one histogram per feature.
# Every figure is closed right after rendering: Streamlit reruns the script on
# each interaction and pyplot keeps unclosed figures alive, so the per-feature
# loop in particular would otherwise pile up figures.
st.write("### Additional Visualizations")
# NOTE(review): the leading "π" in the spinner text looks like the residue of
# a mis-decoded emoji whose code point cannot be recovered; confirm the glyph.
with st.spinner("π Generating more graphs..."):
    # Pairplot
    st.write("#### Pairplot")
    pairplot_fig = sns.pairplot(df, hue=target)
    # sns.pairplot returns a PairGrid; hand its underlying Matplotlib figure
    # to Streamlit and close that same figure afterwards.
    st.pyplot(pairplot_fig.figure)
    plt.close(pairplot_fig.figure)

    # Correlation Heatmap
    # Correlation is only defined for numeric data; on pandas >= 2.0 calling
    # df.corr() on a frame with text columns raises, so restrict it first.
    st.write("#### Correlation Heatmap")
    numeric_df = df.select_dtypes(include=["number", "bool"])
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(numeric_df.corr(), annot=True, cmap="coolwarm", ax=ax)
    ax.set_title("Correlation Matrix")
    st.pyplot(fig)
    plt.close(fig)

    # Histogram for each feature
    st.write("#### Feature Distributions")
    for col in X.columns:
        fig, ax = plt.subplots(figsize=(6, 4))
        sns.histplot(df[col], kde=True, ax=ax)
        ax.set_title(f"Distribution of {col}")
        st.pyplot(fig)
        plt.close(fig)
# Sidebar control that writes the trained classifier to disk on request
st.sidebar.write("### Save Model")
save_model = st.sidebar.button("Save Model")
if save_model:
    # Imported lazily: joblib is only needed once the button has been pressed
    from joblib import dump as dump_model

    model_path = "alzheimers_model.pkl"
    # NOTE(review): only `rf` is written; the fitted scaler is not part of the file
    dump_model(rf, model_path)
    st.sidebar.success("π Model saved as 'alzheimers_model.pkl'!")
# End of the app: closing rule, thank-you message and the balloons animation.
# The stray "|" that followed st.balloons() has been removed: it left a binary
# operator without a right-hand operand, which is a syntax error.
st.write("---")
st.write("🧠 Thank you for using the Alzheimer's Disease Prediction App!")
st.balloons()