# disease / app.py
# Uploaded by drwaseem -- "Update app.py" (commit 59edd4a, verified)
import pandas as pd
import numpy as np
import streamlit as st
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
accuracy_score,
f1_score,
confusion_matrix,
roc_curve,
auc,
precision_recall_curve,
)
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.datasets import make_classification
from mpl_toolkits.mplot3d import Axes3D
# ---------------------------------------------------------------------------
# Alzheimer's risk demo app (Streamlit).
#
# Flow: load a CSV (or generate synthetic data) -> clean/encode it -> train a
# RandomForestClassifier -> show metrics and plots -> interactive prediction.
#
# NOTE: Streamlit re-executes this script on every widget interaction, so the
# model is refit on each rerun. Every Matplotlib figure is therefore created
# explicitly and closed right after rendering (see ``_show_figure``).
# NOTE: emoji are kept in Streamlit text only; Matplotlib's default fonts have
# no emoji glyphs, so plot titles/labels are plain text.
# ---------------------------------------------------------------------------

TARGET_COLUMN = "AlzheimerRisk"
CLASS_LABELS = ["No Risk", "At Risk"]  # position 0 -> class 0, position 1 -> class 1
PREVIEW_ROWS = 250
MAX_BAR_GROUPS = 20  # numeric features with more distinct values are binned


# ----------------------------- data helpers --------------------------------

def make_synthetic_dataset():
    """Return the synthetic fallback dataset (1000 rows, 10 features + target)."""
    X, y = make_classification(
        n_samples=1000,
        n_features=10,
        n_informative=5,
        n_redundant=2,
        n_classes=2,
        random_state=42,
    )
    columns = [f"Feature_{i}" for i in range(X.shape[1])]
    data = pd.DataFrame(X, columns=columns)
    data[TARGET_COLUMN] = y
    return data


def handle_missing_values(data, target_column=TARGET_COLUMN):
    """Return a copy of *data* without missing values.

    Rows with a missing target are dropped, columns that are entirely empty
    are dropped, numeric gaps are filled with the column median and
    non-numeric gaps with the literal category ``"Missing"``.
    """
    cleaned = data.dropna(subset=[target_column]).copy()
    if cleaned.empty:
        return cleaned
    cleaned = cleaned.dropna(axis=1, how="all")
    for col in cleaned.columns:
        if col == target_column or not cleaned[col].isna().any():
            continue
        if pd.api.types.is_numeric_dtype(cleaned[col]):
            cleaned[col] = cleaned[col].fillna(cleaned[col].median())
        else:
            cleaned[col] = cleaned[col].fillna("Missing")
    return cleaned


def encode_categoricals(data):
    """Label-encode every non-numeric column.

    Returns ``(encoded_frame, encoders)`` where ``encoders`` maps column name
    to its fitted ``LabelEncoder``. Values are cast to ``str`` first so that
    mixed-type columns can be sorted by the encoder.
    """
    encoded = data.copy()
    encoders = {}
    for col in encoded.columns:
        if pd.api.types.is_numeric_dtype(encoded[col]):
            continue
        encoder = LabelEncoder()
        encoded[col] = encoder.fit_transform(encoded[col].astype(str))
        encoders[col] = encoder
    return encoded, encoders


def binarize_target(target):
    """Coerce *target* to integer 0/1 labels.

    Returns ``(labels, was_binarized)``. Values already restricted to 0/1
    (including booleans) are only cast to ``int``; anything else is
    thresholded at 0.5 and ``was_binarized`` is ``True``.
    """
    if pd.api.types.is_bool_dtype(target):
        return target.astype(int), False
    if pd.api.types.is_numeric_dtype(target) and target.isin([0, 1]).all():
        return target.astype(int), False
    return (target >= 0.5).astype(int), True


def class_counts(target):
    """Return the counts of class 0 and class 1, in that order.

    ``value_counts`` sorts by frequency, so the result is re-indexed to keep
    the counts aligned with ``CLASS_LABELS``.
    """
    return target.value_counts().reindex([0, 1], fill_value=0)


def risk_by_feature(data, feature, target_column=TARGET_COLUMN, max_bars=MAX_BAR_GROUPS):
    """Return the mean target value per group of *feature*.

    A feature with more than *max_bars* distinct values is cut into
    *max_bars* equal-width bins first, so the bar chart stays readable.
    """
    groups = data[feature]
    if groups.nunique() > max_bars and pd.api.types.is_numeric_dtype(groups):
        groups = pd.cut(groups, bins=max_bars)
    return data[target_column].groupby(groups, observed=True).mean()


def top_class(classes, proba):
    """Return ``(label, probability)`` of the most probable class.

    The probability vector is indexed by position, which is then mapped
    through *classes*; the class label itself is never used as an index.
    """
    best = int(np.argmax(proba))
    return classes[best], float(proba[best])


# ----------------------------- plot helpers --------------------------------

def _show_figure(fig):
    """Render *fig* in the app and close it so figures do not pile up across reruns."""
    st.pyplot(fig)
    plt.close(fig)


def _render_age_distribution(data):
    st.write("### 📊 Age Distribution 🧠")
    if "Age" not in data.columns:
        st.warning("⚠️ Age column not found in the dataset! 🧠")
        return
    fig, ax = plt.subplots(figsize=(10, 6))
    sns.histplot(data["Age"], kde=True, color="dodgerblue", bins=20, ax=ax)
    ax.set_title("Age Distribution")
    _show_figure(fig)


def _render_confusion_matrix(y_test, y_pred):
    st.write("### 📊 Confusion Matrix 🧠")
    # Fix the label order so the matrix is always 2x2 and matches the tick labels.
    cm = confusion_matrix(y_test, y_pred, labels=[0, 1])
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=CLASS_LABELS,
        yticklabels=CLASS_LABELS,
        ax=ax,
    )
    ax.set_title("Confusion Matrix")
    ax.set_ylabel("True label")
    ax.set_xlabel("Predicted label")
    _show_figure(fig)


def _render_feature_importance(model, features):
    st.write("### 📊 Feature Importance 🧠")
    importances = model.feature_importances_
    order = np.argsort(importances)[::-1]
    names = [str(features[i]) for i in order]
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.barh(names, importances[order], color=sns.color_palette("viridis", len(names)))
    ax.invert_yaxis()  # most important feature at the top
    ax.set_title("Feature Importance")
    ax.set_xlabel("Importance Score")
    ax.set_ylabel("Features")
    _show_figure(fig)


def _render_line_and_area(data, features):
    st.write("### 📈 Line Graph 🧠")
    line_feature = st.selectbox("Select feature for Line Graph:", features)
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(data.index, data[line_feature], color="green")
    ax.set_title(f"Line Graph of {line_feature}")
    ax.set_xlabel("Index")
    ax.set_ylabel(line_feature)
    _show_figure(fig)

    st.write("### 📉 Area Graph 🧠")
    area_feature = st.selectbox("Select feature for Area Graph:", features)
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(data.index, data[area_feature], color="orange", linewidth=2)
    ax.fill_between(data.index, data[area_feature], color="orange", alpha=0.3)
    ax.set_title(f"Area Graph of {area_feature}")
    ax.set_xlabel("Index")
    ax.set_ylabel(area_feature)
    _show_figure(fig)


def _render_visualizations(data, features):
    st.write("### 📊 Data Visualizations 🧠")
    visualization_type = st.selectbox(
        "Choose a visualization type 🧠:",
        ["2D Scatter Plot", "3D Scatter Plot", "Bar Chart", "Pie Chart", "Histogram"],
    )
    last = len(features) - 1  # clamp default selections for narrow datasets

    if visualization_type == "2D Scatter Plot":
        x_col = st.selectbox("Select X-axis feature 🧠:", features)
        y_col = st.selectbox("Select Y-axis feature 🧠:", features, index=min(1, last))
        fig, ax = plt.subplots(figsize=(10, 6))
        sns.scatterplot(data=data, x=x_col, y=y_col, hue=TARGET_COLUMN, palette="viridis", ax=ax)
        ax.set_title("2D Scatter Plot")
        _show_figure(fig)

    elif visualization_type == "3D Scatter Plot":
        x_col = st.selectbox("Select X-axis feature 🧠:", features)
        y_col = st.selectbox("Select Y-axis feature 🧠:", features, index=min(1, last))
        z_col = st.selectbox("Select Z-axis feature 🧠:", features, index=min(2, last))
        fig = plt.figure(figsize=(10, 8))
        ax = fig.add_subplot(111, projection="3d")
        scatter = ax.scatter(
            data[x_col], data[y_col], data[z_col], c=data[TARGET_COLUMN], cmap="viridis", s=50
        )
        ax.set_xlabel(x_col)
        ax.set_ylabel(y_col)
        ax.set_zlabel(z_col)
        fig.colorbar(scatter, ax=ax, label=TARGET_COLUMN)
        _show_figure(fig)

    elif visualization_type == "Bar Chart":
        bar_feature = st.selectbox("Select feature for Bar Chart 🧠:", features)
        fig, ax = plt.subplots(figsize=(10, 6))
        risk_by_feature(data, bar_feature).plot(kind="bar", color="skyblue", ax=ax)
        ax.set_title("Bar Chart of Risk by Feature")
        ax.set_xlabel(bar_feature)
        ax.set_ylabel("Average Risk")
        _show_figure(fig)

    elif visualization_type == "Pie Chart":
        fig, ax = plt.subplots(figsize=(8, 8))
        ax.pie(
            class_counts(data[TARGET_COLUMN]),
            labels=CLASS_LABELS,
            autopct="%1.1f%%",
            startangle=140,
            colors=["green", "red"],
        )
        ax.set_title("Distribution of Alzheimer's Risk")
        _show_figure(fig)

    elif visualization_type == "Histogram":
        hist_feature = st.selectbox("Select feature for Histogram 🧠:", features)
        fig, ax = plt.subplots(figsize=(10, 6))
        sns.histplot(
            data=data, x=hist_feature, hue=TARGET_COLUMN, kde=True, palette="viridis", ax=ax
        )
        ax.set_title("Histogram")
        _show_figure(fig)


def _render_roc_and_pr(model, X_test, y_test):
    # Column of predict_proba that belongs to the positive class (label 1).
    positive_idx = list(model.classes_).index(1)
    y_proba = model.predict_proba(X_test)[:, positive_idx]

    st.write("### 📈 ROC Curve 🧠")
    fpr, tpr, _ = roc_curve(y_test, y_proba)
    roc_auc = auc(fpr, tpr)
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(fpr, tpr, color="blue", lw=2, label=f"ROC Curve (AUC = {roc_auc:.2f})")
    ax.plot([0, 1], [0, 1], color="gray", linestyle="--")
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title("Receiver Operating Characteristic (ROC) Curve")
    ax.legend(loc="lower right")
    _show_figure(fig)

    st.write("### 📉 Precision-Recall Curve 🧠")
    precision, recall, _ = precision_recall_curve(y_test, y_proba)
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(recall, precision, color="green", lw=2, label="Precision-Recall Curve")
    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.set_title("Precision-Recall Curve")
    ax.legend(loc="upper right")
    _show_figure(fig)


def _render_prediction_form(model, data, features, label_encoders):
    st.write("### 🧮 Predict Alzheimer's Risk 🧠")
    input_data = {}
    for feature in features:
        encoder = label_encoders.get(feature)
        if encoder is not None:  # categorical: pick a class, feed its encoded id
            choice = st.selectbox(f"{feature} 🔽", encoder.classes_)
            input_data[feature] = encoder.transform([choice])[0]
        else:  # numeric: default to the column mean
            input_data[feature] = st.number_input(
                f"{feature} ✏️", value=float(data[feature].mean())
            )

    # Keep the training column order explicit.
    input_df = pd.DataFrame([input_data], columns=features)
    prediction, confidence = top_class(model.classes_, model.predict_proba(input_df)[0])

    st.write("### 🩺 Prediction Result 🧠")
    if prediction == 1:
        st.error("🚨 The person is **at risk** of Alzheimer's Disease 🧠.")
    else:
        st.success("✅ The person is **not at risk** of Alzheimer's Disease 🧠.")
    st.write(f"🔍 Prediction Confidence 🧠: {confidence:.2f}")


# ------------------------------- entry point -------------------------------

def main():
    """Run the Streamlit app from top to bottom."""
    # Must be the first Streamlit command of the script.
    st.set_page_config(
        page_title="🧠 Alzheimer's Diagnosis App",
        page_icon="💡",
        layout="wide",
    )
    st.title("🧠 Early Diagnosis of Alzheimer's Disease 🧠")
    st.subheader("🌟 Empowering early intervention for a healthier future! 🌟")

    # Load dataset (uploaded CSV, otherwise synthetic data).
    uploaded_file = st.file_uploader("📂 Upload your dataset (CSV format)", type=["csv"])
    if uploaded_file is not None:
        try:
            data = pd.read_csv(uploaded_file)
        except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as exc:
            st.error(f"❌ Could not read the uploaded CSV: {exc}")
            return
        st.success("✅ Dataset loaded successfully! 🧠")
    else:
        st.warning("⚠️ No file uploaded. Using synthetic data. 🧠")
        data = make_synthetic_dataset()

    st.write("### 🔍 Dataset Preview 🧠")
    st.write(data.head(PREVIEW_ROWS))

    if TARGET_COLUMN not in data.columns:
        st.error("❌ Dataset must contain a column named 'AlzheimerRisk'. 🧠")
        return

    # Preprocessing: missing values -> categorical encoding -> binary target.
    st.write("### 🛠 Data Preprocessing 🧠")
    had_missing = bool(data.isna().any().any())
    data = handle_missing_values(data, TARGET_COLUMN)
    if data.empty:
        st.error("❌ No rows with a usable 'AlzheimerRisk' value were found. 🧠")
        return
    if had_missing:
        st.info(
            "ℹ️ Missing values detected: rows without a target were dropped and "
            "remaining gaps were filled (median / 'Missing')."
        )

    data, label_encoders = encode_categoricals(data)
    st.write("✅ Preprocessed Dataset 🧠", data.head(PREVIEW_ROWS))

    target, was_binarized = binarize_target(data[TARGET_COLUMN])
    if was_binarized:
        st.write("⚠️ Binarizing 'AlzheimerRisk' to binary classification. 🧠")
    data[TARGET_COLUMN] = target

    features = [col for col in data.columns if col != TARGET_COLUMN]
    if not features:
        st.error("❌ Dataset must contain at least one feature column. 🧠")
        return
    if data[TARGET_COLUMN].nunique() < 2:
        st.error("❌ 'AlzheimerRisk' must contain both classes (0 and 1) to train a model. 🧠")
        return

    X = data[features]
    y = data[TARGET_COLUMN]
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    rf_model = RandomForestClassifier(random_state=42, n_estimators=200, max_depth=10)
    rf_model.fit(X_train, y_train)

    y_pred = rf_model.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    f1 = f1_score(y_test, y_pred)
    st.metric("🎯 Accuracy 🧠", f"{accuracy*100:.2f}%")
    st.metric("📊 F1 Score 🧠", f"{f1:.2f}")

    _render_age_distribution(data)
    _render_confusion_matrix(y_test, y_pred)
    _render_feature_importance(rf_model, features)
    _render_line_and_area(data, features)
    _render_visualizations(data, features)
    _render_roc_and_pr(rf_model, X_test, y_test)
    _render_prediction_form(rf_model, data, features, label_encoders)


# `streamlit run` executes the script as "__main__", so the app starts as
# before; importing the module (e.g. for tests) has no side effects.
if __name__ == "__main__":
    main()