Spaces:
Sleeping
Sleeping
import streamlit as st | |
import pandas as pd | |
import numpy as np | |
import plotly.express as px | |
import seaborn as sns | |
import matplotlib.pyplot as plt | |
from io import StringIO | |
from sklearn.impute import KNNImputer, SimpleImputer | |
from sklearn.preprocessing import StandardScaler, MinMaxScaler, RobustScaler, LabelEncoder, OneHotEncoder | |
from sklearn.decomposition import PCA | |
from sklearn.cluster import KMeans | |
from sklearn.model_selection import train_test_split | |
from pycaret.classification import setup, compare_models, pull | |
from scipy.stats import zscore | |
import matplotlib | |
from sklearn.feature_selection import SelectKBest, f_classif | |
from ydata_profiling import ProfileReport | |
from ydata_profiling.config import Settings | |
from functools import lru_cache | |
# ================== 🔹 ENHANCED STYLING ==================
def load_custom_css():
    """Inject the app-wide CSS theme into the current Streamlit page.

    The stylesheet is passed to ``st.markdown`` with ``unsafe_allow_html=True``
    so that the raw ``<style>`` element is rendered rather than escaped.
    """
    # NOTE(review): the starfield layer loads its image from
    # source.unsplash.com, which Unsplash has retired; the background image
    # probably no longer loads — confirm and swap in a hosted asset if so.
    stylesheet = """
<style>
/* ๐ Cosmic Nebula Background */
body, .main {
    background: radial-gradient(circle at top, #10002b 0%, #240046 50%, #3c096c 100%);
    color: #ffffff;
    font-family: 'Poppins', sans-serif;
}
/* ๐ Animated Starfield Effect */
body::before {
    content: "";
    position: fixed;
    top: 0;
    left: 0;
    width: 100%;
    height: 100%;
    background: url('https://source.unsplash.com/random/1600x900/?stars,galaxy,nebula') center/cover no-repeat;
    opacity: 0.1;
    z-index: -1;
}
/* ๐ช Glassmorphism Containers */
.stContainer, .stExpander, .stDataFrame {
    background: rgba(255, 255, 255, 0.08) !important;
    backdrop-filter: blur(15px);
    border-radius: 15px;
    border: 1px solid rgba(255, 255, 255, 0.12);
    padding: 1.5rem;
    box-shadow: 0 10px 30px rgba(255, 255, 255, 0.12);
}
/* ๐ฎ Cyberpunk Buttons */
.stButton>button {
    background: linear-gradient(90deg, #ff00ff, #00ffff);
    color: white !important;
    border: none;
    border-radius: 12px;
    padding: 0.8rem 1.5rem;
    font-weight: bold;
    letter-spacing: 0.05rem;
    transition: all 0.4s ease;
    text-transform: uppercase;
    width: 100%;
}
.stButton>button:hover {
    transform: scale(1.05);
    box-shadow: 0 0 20px rgba(0, 255, 255, 0.8);
}
/* ๐ Neon Headers */
h1, h2, h3, h4, h5, h6 {
    font-weight: bold;
    text-transform: uppercase;
    text-shadow: 0 0 10px rgba(0, 255, 255, 0.6);
    color: #00ffff;
    padding: 0.5rem 0;
}
/* ๐ Interactive Inputs */
.stTextInput>div>div>input,
.stSelectbox>div>div>div,
.stSlider>div>div>div {
    background: rgba(0, 0, 0, 0.5) !important;
    border-radius: 10px !important;
    padding: 0.75rem !important;
    color: white !important;
    border: 1px solid rgba(255, 255, 255, 0.3) !important;
    transition: all 0.3s ease;
}
.stTextInput>div>div>input:focus,
.stSelectbox>div>div>div:hover {
    border-color: #ff00ff !important;
    box-shadow: 0 0 12px rgba(255, 0, 255, 0.6);
}
/* ๐ญ Data Grid Styling */
[data-testid="stDataFrame"] {
    border: 1px solid rgba(255, 255, 255, 0.2);
    border-radius: 10px;
    background: rgba(255, 255, 255, 0.05);
    padding: 1rem;
    color: white !important;
}
/* ๐ Graph Enhancements */
.stPlotlyChart, .stPydeckChart {
    border-radius: 15px;
    border: 1px solid rgba(255, 255, 255, 0.1);
    padding: 1rem;
    box-shadow: 0 8px 20px rgba(255, 255, 255, 0.15);
}
/* ๐๏ธ Consistent Spacing */
.stContainer > *,
.stExpander > * {
    margin: 1rem 0;
}
/* ๐ Futuristic Scrollbars */
::-webkit-scrollbar {
    width: 8px;
    height: 8px;
}
::-webkit-scrollbar-track {
    background: rgba(25, 25, 45, 0.5);
}
::-webkit-scrollbar-thumb {
    background: linear-gradient(180deg, #ff00ff, #00ffff);
    border-radius: 4px;
    box-shadow: 0 0 10px rgba(255, 255, 255, 0.3);
}
/* โจ Smooth Animations */
* {
    transition: all 0.25s ease-in-out;
}
</style>
"""
    st.markdown(stylesheet, unsafe_allow_html=True)


# Apply the theme once per script run, before any page content is drawn.
load_custom_css()
# ================== 🔹 STATISTICS & CHART HELPERS ==================
def calculate_statistics(df, column):
    """Return summary statistics for a single column of ``df``.

    Columns that ``pd.api.types.is_numeric_dtype`` treats as numeric yield the
    keys ``mean``, ``median``, ``std``, ``min`` and ``max``. Any other column
    yields ``unique_values`` (number of distinct non-null values) and
    ``most_common`` (the first modal value, or ``None`` when the column holds
    no non-null values at all).

    Nothing is memoised: the statistics are recomputed on every call.

    Args:
        df: The DataFrame containing ``column``.
        column: Label of the column to summarise.

    Returns:
        A dict of statistics as described above.

    Raises:
        KeyError: if ``column`` is not in ``df``.
    """
    series = df[column]
    if pd.api.types.is_numeric_dtype(series):
        return {
            "mean": series.mean(),
            "median": series.median(),
            "std": series.std(),
            "min": series.min(),
            "max": series.max(),
        }
    modes = series.mode()
    return {
        "unique_values": series.nunique(),
        # mode() is empty for an all-null column; indexing it would raise
        # IndexError, so report "no most common value" instead.
        "most_common": modes.iloc[0] if not modes.empty else None,
    }
def generate_chart(df, chart_type, x_col, y_col=None, z_col=None):
    """Build a dark-themed Plotly Express figure for the EDA page.

    Args:
        df: DataFrame holding the columns to plot.
        chart_type: One of "Histogram", "Bar Chart", "Box Plot",
            "Violin Plot", "Scatter Plot", "3D Scatter" or "Heatmap".
        x_col: Primary column; the only one used by the single-variable
            charts ("Histogram", "Bar Chart", "Box Plot", "Violin Plot").
        y_col: Second column for "Scatter Plot", "3D Scatter" and "Heatmap".
        z_col: Third column, used only by "3D Scatter".

    Returns:
        A Plotly figure.

    Raises:
        ValueError: if ``chart_type`` is not one of the names above. (Before,
            an unknown name returned ``None``, which only failed later inside
            ``st.plotly_chart`` with a confusing message.)
    """
    if chart_type == "Histogram":
        return px.histogram(df, x=x_col, nbins=30, title=f"Distribution of {x_col}",
                            color_discrete_sequence=['#00cc96'], template="plotly_dark")
    if chart_type == "Bar Chart":
        # Frequency of each distinct value. The EDA auto-detect picks this for
        # non-numeric columns, so it must be supported here.
        counts = df[x_col].value_counts()
        return px.bar(x=counts.index.astype(str), y=counts.values,
                      labels={"x": x_col, "y": "Count"},
                      title=f"Frequency of {x_col}",
                      color_discrete_sequence=['#17becf'], template="plotly_dark")
    if chart_type == "Box Plot":
        return px.box(df, y=x_col, title=f"Box Plot of {x_col}",
                      color_discrete_sequence=['#ff7f0e'], template="plotly_dark")
    if chart_type == "Violin Plot":
        return px.violin(df, y=x_col, title=f"Violin Plot of {x_col}",
                         color_discrete_sequence=['#9467bd'], template="plotly_dark")
    if chart_type == "Scatter Plot":
        return px.scatter(df, x=x_col, y=y_col, title=f"{x_col} vs {y_col}",
                          color_discrete_sequence=['#1f77b4'], template="plotly_dark")
    if chart_type == "3D Scatter":
        return px.scatter_3d(df, x=x_col, y=y_col, z=z_col,
                             title=f"3D Analysis: {x_col} vs {y_col} vs {z_col}",
                             color_discrete_sequence=['#2ca02c'], template="plotly_dark")
    if chart_type == "Heatmap":
        corr_matrix = df[[x_col, y_col]].corr()
        return px.imshow(corr_matrix, text_auto=True, title="Correlation Heatmap",
                         color_continuous_scale='Viridis', template="plotly_dark")
    raise ValueError(f"Unsupported chart type: {chart_type!r}")
# ================== 🔹 LAZY-LOADING COMPONENTS ==================
def lazy_load_chart(df, chart_type, x_col, y_col=None, z_col=None):
    """Build a chart with ``generate_chart`` while a Streamlit spinner is shown.

    Args:
        df, chart_type, x_col, y_col: Passed straight to ``generate_chart``.
        z_col: Third axis, forwarded so "3D Scatter" can be lazy-loaded too
            (it was previously dropped, leaving the z axis as ``None``).
            Defaults to ``None``, so existing callers are unaffected.

    Returns:
        Whatever ``generate_chart`` returns for these arguments.
    """
    with st.spinner(f"Generating {chart_type}..."):
        return generate_chart(df, chart_type, x_col, y_col, z_col)
def lazy_load_statistics(df, column):
    """Compute column statistics via ``calculate_statistics`` under a spinner.

    Args:
        df: Source DataFrame.
        column: Column label to summarise.

    Returns:
        The dict produced by ``calculate_statistics(df, column)``.
    """
    busy_indicator = st.spinner("Calculating statistics...")
    with busy_indicator:
        summary = calculate_statistics(df, column)
    return summary
# ================== 🔹 SESSION STATE ==================
# Streamlit re-executes this script on every interaction; seed each shared
# slot with None only the first time so stored values survive reruns.
for _state_key in ("df", "cleaned_df", "X_train", "X_test", "y_train", "y_test", "model"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None
del _state_key
# ================== 🔹 GLOBAL NAVIGATION ==================
# Sidebar title plus the page selector; `choice` drives the page branches below.
st.sidebar.title("๐ Nexus Analytics")
choice = st.sidebar.radio(
    "Go to",
    [
        "Home",
        "Data Cleaning",
        "EDA",
        "Train-Test Split",
        "Machine Learning",
        "Predictions",
        "Visualization",
    ],
)
if choice == "Home":
    st.title("๐ Upload Your Dataset")

    # Dataset control buttons: clear the loaded dataset or replace it.
    control_col1, control_col2 = st.columns([1, 2])
    with control_col1:
        if st.session_state.df is not None:
            if st.button("๐งน Clear Dataset", help="Remove current dataset from memory"):
                st.session_state.df = None
                st.session_state.cleaned_df = None
                st.success("Dataset cleared from memory!")
    with control_col2:
        replace_file = st.file_uploader("Replace Dataset", type=["csv", "xlsx"],
                                        help="Upload a new dataset to replace current one",
                                        key="replace_uploader")
        if replace_file:
            # Streamlit reruns this script on every interaction and the
            # uploader keeps returning the same file. Load it only when a new
            # upload arrives. Otherwise every visit to Home would reset
            # cleaned_df (discarding cleaning work) and immediately undo
            # "Clear Dataset".
            upload_id = getattr(replace_file, "file_id",
                                (replace_file.name, replace_file.size))
            if st.session_state.get("replace_upload_id") != upload_id:
                try:
                    # Case-insensitive so e.g. "DATA.CSV" is parsed as CSV.
                    if replace_file.name.lower().endswith('.csv'):
                        df = pd.read_csv(replace_file)
                    else:
                        df = pd.read_excel(replace_file)
                except Exception as e:  # malformed/unsupported file: report instead of crashing
                    st.error(f"Could not read dataset: {e}")
                else:
                    st.session_state.replace_upload_id = upload_id
                    st.session_state.df = df
                    st.session_state.cleaned_df = df.copy()
                    st.success("โ Dataset replaced successfully!")

    # Main dataset upload, shown only while no dataset is loaded.
    if st.session_state.df is None:
        with st.container():
            uploaded_file = st.file_uploader("Upload Dataset", type=["csv", "xlsx"],
                                             help="Drag and drop your dataset file here")
            if uploaded_file:
                try:
                    if uploaded_file.name.lower().endswith('.csv'):
                        df = pd.read_csv(uploaded_file)
                    else:
                        df = pd.read_excel(uploaded_file)
                except Exception as e:
                    st.error(f"Could not read dataset: {e}")
                else:
                    st.session_state.df = df
                    st.session_state.cleaned_df = df.copy()
                    st.success("โ Data uploaded successfully!")

    # Show dataset information if loaded.
    if st.session_state.df is not None:
        df = st.session_state.df

        # Dataset overview cards.
        with st.container():
            col1, col2, col3 = st.columns(3)
            with col1:
                with st.container():
                    st.markdown("### ๐ Dataset Shape")
                    st.markdown(f"**{df.shape[0]}** Rows | **{df.shape[1]}** Columns")
            with col2:
                with st.container():
                    st.markdown("### โ ๏ธ Data Issues")
                    st.markdown(f"**{df.isnull().sum().sum()}** Missing Values | **{df.duplicated().sum()}** Duplicates")
            with col3:
                with st.container():
                    st.markdown("### ๐งฌ Data Types")
                    num_cols = len(df.select_dtypes(include=np.number).columns)
                    cat_cols = len(df.select_dtypes(include=['object']).columns)
                    st.markdown(f"**{num_cols}** Numerical | **{cat_cols}** Categorical")

        # Automated data report (ydata-profiling, minimal mode).
        with st.expander("๐ Automated Data Report", expanded=True):
            if st.button("โจ Generate Smart Report"):
                with st.spinner("๐ Analyzing dataset..."):
                    # Configure a minimal report without samples or descriptions.
                    config = Settings()
                    config.title = " "
                    config.variables.descriptions = False
                    config.show_variable_description = False
                    config.samples.head = 0
                    config.samples.tail = 0
                    profile = ProfileReport(
                        df,
                        config=config,
                        minimal=True,
                    )
                    # Inject a custom colour scheme and hide the report's own <h1>.
                    report_html = profile.to_html()
                    report_html = report_html.replace(
                        ':root {',
                        ':root { --primary-color: #00f7ff; --secondary-color: #0066ff;'
                    )
                    report_html = report_html.replace('<h1', '<h1 style="display:none"')
                    st.components.v1.html(report_html, height=800, scrolling=True)

        # Interactive data explorer.
        st.subheader("๐ Data Explorer")

        # Data samples.
        with st.expander("๐ Data Samples", expanded=True):
            sample_type = st.selectbox("View Data Samples",
                                       ["First 5 Rows", "Last 5 Rows", "Random Sample"],
                                       key="sample_selector")
            if sample_type == "First 5 Rows":
                st.dataframe(df.head().style.highlight_null(color='#FF6666'), use_container_width=True)
            elif sample_type == "Last 5 Rows":
                st.dataframe(df.tail().style.highlight_null(color='#FF6666'), use_container_width=True)
            else:
                max_sample = min(100, len(df))
                if max_sample > 5:
                    # Default of 10 is capped so it never exceeds the slider max.
                    sample_size = st.slider("Sample Size", 5, max_sample, min(10, max_sample))
                else:
                    # st.slider needs min < max; with <= 5 rows just shuffle them all.
                    sample_size = max_sample
                st.dataframe(df.sample(sample_size).style.highlight_null(color='#FF6666'), use_container_width=True)

        # Column analysis.
        with st.expander("๐ Column Insights", expanded=True):
            col1, col2 = st.columns(2)
            with col1:
                selected_col = st.selectbox("Select Column", df.columns)
                if pd.api.types.is_numeric_dtype(df[selected_col]):
                    fig = px.histogram(df, x=selected_col,
                                       title=f"Distribution of {selected_col}",
                                       color_discrete_sequence=['#00f7ff'])
                    st.plotly_chart(fig, use_container_width=True)
                else:
                    value_counts = df[selected_col].value_counts().nlargest(10)
                    fig = px.bar(value_counts,
                                 title=f"Top 10 Values in {selected_col}",
                                 color_discrete_sequence=['#0066ff'])
                    st.plotly_chart(fig, use_container_width=True)
            with col2:
                st.markdown("#### Column Summary")
                st.write(f"**Data Type:** {df[selected_col].dtype}")
                st.write(f"**Unique Values:** {df[selected_col].nunique()}")
                if pd.api.types.is_numeric_dtype(df[selected_col]):
                    st.write(f"**Min Value:** {df[selected_col].min():.2f}")
                    st.write(f"**Max Value:** {df[selected_col].max():.2f}")
                    st.write(f"**Mean Value:** {df[selected_col].mean():.2f}")
                else:
                    st.write("**Most Common Value:**")
                    modes = df[selected_col].mode()
                    # mode() is empty when every value is null; avoid IndexError.
                    st.write(modes.iloc[0] if not modes.empty else "No non-null values")

        # Data summary tabs.
        tab1, tab2, tab3 = st.tabs(["๐ Full Summary", "๐ Statistics", "๐ง AI Insights"])
        with tab1:
            buffer = StringIO()
            df.info(buf=buffer)
            st.text(buffer.getvalue())
        with tab2:
            st.write(df.describe().style.background_gradient(cmap='Blues'))
        with tab3:
            st.markdown("### Automated Insights")
            if st.button("๐ฎ Generate AI-Powered Insights"):
                with st.spinner("๐ค Analyzing patterns..."):
                    profile = ProfileReport(df, minimal=True)
                    # The report is a complete HTML document with its own
                    # CSS/JS. Embed it in an iframe, as the Smart Report above
                    # does, rather than passing it through st.write/markdown.
                    st.components.v1.html(profile.to_html(), height=800, scrolling=True)
# ================== ๐น ENHANCED DATA CLEANING SECTION ================== | |
elif choice == "Data Cleaning": | |
st.header("๐งผ Intelligent Data Wrangling") | |
if st.session_state.df is not None: | |
df = st.session_state.cleaned_df.copy() | |
# AI-Powered Cleaning Assistant | |
st.subheader("๐ค Smart Cleaning Advisor") | |
if st.button("Run Full Data Diagnosis", type="primary"): | |
with st.spinner("๐ Performing multidimensional analysis..."): | |
try: | |
# Advanced data quality assessment | |
numeric_cols = df.select_dtypes(include=np.number).columns | |
diagnosis = pd.DataFrame({ | |
'Metric': ['Missing Values', 'Duplicate Rows', | |
'Zero Variance', 'Data Leakage Risk'], | |
'Value': [ | |
f"{df.isnull().sum().sum()} ({df.isnull().mean().mean():.1%})", | |
df.duplicated().sum(), | |
df[numeric_cols].std()[df[numeric_cols].std() == 0].count(), | |
"High" if df.skew().abs().max() > 5 else "Low" | |
], | |
'Severity': ['Critical' if df.isnull().sum().sum() > 0 else 'OK', | |
'Warning' if df.duplicated().sum() > 0 else 'OK', | |
'Critical' if df[numeric_cols].std()[df[numeric_cols].std() == 0].count() > 0 else 'OK', | |
'Warning' if df.skew().abs().max() > 5 else 'OK'] | |
}) | |
# Visualize data health | |
fig = px.bar(diagnosis, x='Metric', y='Value', color='Severity', | |
color_discrete_map={'Critical':'#ff2b2b','Warning':'#f0c929','OK':'#00ff87'}, | |
template="plotly_dark") | |
st.plotly_chart(fig, use_container_width=True) | |
except Exception as e: | |
st.error(f"Diagnostic failed: {str(e)}") | |
# Professional-Grade Cleaning Tools | |
st.subheader("๐ง Enterprise Cleaning Toolkit") | |
tab1, tab2, tab3, tab4 = st.tabs(["๐งฉ Missing Data", "๐ Normalization", "๐ Outliers", "๐ Encoding"]) | |
with tab1: | |
cols = st.columns([1,3]) | |
with cols[0]: | |
imp_method = st.selectbox("Imputation Strategy", | |
["ML Impute (Iterative)", "KNN", "MICE", "Matrix Factorization"], | |
help="Select advanced imputation technique") | |
if imp_method == "KNN": | |
n_neighbors = st.slider("Neighbors", 3, 15, 5, help="Number of similar records to consider") | |
with cols[1]: | |
if st.button("Execute Smart Imputation", type="primary"): | |
with st.spinner(f"โ๏ธ Running {imp_method}..."): | |
# Advanced imputation logic | |
numeric_cols = df.select_dtypes(include=np.number).columns | |
if imp_method == "KNN": | |
imputer = KNNImputer(n_neighbors=n_neighbors) | |
df[numeric_cols] = imputer.fit_transform(df[numeric_cols]) | |
else: | |
df[numeric_cols] = df[numeric_cols].fillna(df[numeric_cols].median()) | |
st.session_state.cleaned_df = df | |
st.toast("Imputation complete!", icon="โ ") | |
with tab2: | |
cols = st.columns([1,3]) | |
with cols[0]: | |
scale_method = st.selectbox("Scaling Algorithm", | |
["Robust Scaling", "Quantum Normalization", | |
"Adaptive MinMax", "Power Transform"], | |
index=0) | |
if scale_method == "Power Transform": | |
lambda_val = st.slider("Lambda Parameter", -3.0, 3.0, 0.0) | |
with cols[1]: | |
if st.button("Apply Feature Engineering", type="primary"): | |
with st.spinner("Transforming features..."): | |
# Advanced scaling logic | |
numeric_cols = df.select_dtypes(include=np.number).columns | |
if scale_method == "Robust Scaling": | |
scaler = RobustScaler() | |
df[numeric_cols] = scaler.fit_transform(df[numeric_cols]) | |
st.session_state.cleaned_df = df | |
st.toast("Features transformed!", icon="โ ") | |
# Real-time Data Diff Viewer | |
st.subheader("๐ Version Comparison") | |
cols = st.columns(2) | |
with cols[0]: | |
st.write("Original Data Snapshot") | |
st.dataframe(st.session_state.df.head(3).style.highlight_null(color='#ff2b2b')) | |
with cols[1]: | |
st.write("Processed Version") | |
st.dataframe(df.head(3).style.highlight_null(color='#00ff87')) | |
# ================== ๐น EDA SECTION ================== | |
elif choice == "EDA": | |
st.header("๐ Advanced Exploratory Data Analysis") | |
if st.session_state.cleaned_df is not None: | |
df = st.session_state.cleaned_df | |
# ================== ๐น USER INPUTS ================== | |
st.subheader("๐ Select Analysis Type") | |
analysis_type = st.radio( | |
"Choose Analysis Type", | |
["Single Variable", "Multi-Variable", "3D Analysis"], | |
horizontal=True, | |
help="Select the type of analysis you want to perform" | |
) | |
# Dynamic Column Selection Based on Analysis Type | |
if analysis_type == "Single Variable": | |
selected_columns = st.multiselect( | |
"Select Columns for Analysis", | |
df.columns, | |
default=df.columns[:1], | |
help="Choose one or more columns for single-variable analysis" | |
) | |
chart_type = st.selectbox( | |
"Select Chart Type", | |
["Auto-Detect", "Histogram", "Box Plot", "Violin Plot"] | |
) | |
elif analysis_type == "Multi-Variable": | |
selected_columns = st.multiselect( | |
"Select Columns for Analysis", | |
df.columns, | |
default=df.columns[:2], | |
help="Choose two or more columns for multi-variable analysis" | |
) | |
chart_type = st.selectbox( | |
"Select Chart Type", | |
["Auto-Detect", "Scatter Plot", "Heatmap", "Box Plot", "Violin Plot"] | |
) | |
else: # 3D Analysis | |
col1, col2, col3 = st.columns(3) | |
with col1: | |
x_col = st.selectbox("X Axis", df.columns) | |
with col2: | |
y_col = st.selectbox("Y Axis", df.columns) | |
with col3: | |
z_col = st.selectbox("Z Axis", df.columns) | |
chart_type = "3D Scatter" | |
# ================== ๐น AUTO-PLOT BUTTON ================== | |
if st.button("โจ Generate Advanced Visualizations", type="primary"): | |
with st.spinner("๐ Generating insights..."): | |
try: | |
# Auto-Detect Logic | |
if chart_type == "Auto-Detect": | |
if analysis_type == "Single Variable": | |
if pd.api.types.is_numeric_dtype(df[selected_columns[0]]): | |
chart_type = "Histogram" | |
else: | |
chart_type = "Bar Chart" | |
elif analysis_type == "Multi-Variable": | |
if all(pd.api.types.is_numeric_dtype(df[col]) for col in selected_columns[:2]): | |
chart_type = "Scatter Plot" | |
else: | |
chart_type = "Box Plot" | |
# Generate Visualization | |
if analysis_type == "Single Variable": | |
col = selected_columns[0] | |
fig = generate_chart(df, chart_type, col) | |
stats = calculate_statistics(df, col) | |
# Display results | |
col1, col2 = st.columns([2, 1]) | |
with col1: | |
st.plotly_chart(fig, use_container_width=True) | |
with col2: | |
st.subheader("๐ Key Insights") | |
if pd.api.types.is_numeric_dtype(df[col]): | |
st.metric("Mean", f"{stats['mean']:.2f}") | |
st.metric("Median", f"{stats['median']:.2f}") | |
st.metric("Std Dev", f"{stats['std']:.2f}") | |
else: | |
st.metric("Unique Values", stats['unique_values']) | |
st.metric("Most Common", stats['most_common']) | |
elif analysis_type == "Multi-Variable": | |
if len(selected_columns) < 2: | |
st.warning("Please select at least two columns") | |
else: | |
fig = generate_chart(df, chart_type, selected_columns[0], selected_columns[1]) | |
st.plotly_chart(fig, use_container_width=True) | |
# Correlation insights | |
if chart_type in ["Scatter Plot", "Heatmap"]: | |
st.subheader("๐ Correlation Insights") | |
try: | |
corr = df[selected_columns[0]].corr(df[selected_columns[1]]) | |
st.write(f"**Correlation Coefficient:** {corr:.2f}") | |
st.progress(abs(corr)) | |
st.caption("Absolute correlation strength") | |
except: | |
st.warning("Could not calculate correlation for selected columns") | |
elif analysis_type == "3D Analysis": | |
fig = generate_chart(df, "3D Scatter", x_col, y_col, z_col) | |
st.plotly_chart(fig, use_container_width=True) | |
# 3D Analysis Insights | |
st.subheader("๐ 3D Analysis Insights") | |
col1, col2, col3 = st.columns(3) | |
with col1: | |
st.metric("X Range", f"{df[x_col].min():.2f} - {df[x_col].max():.2f}") | |
with col2: | |
st.metric("Y Range", f"{df[y_col].min():.2f} - {df[y_col].max():.2f}") | |
with col3: | |
st.metric("Z Range", f"{df[z_col].min():.2f} - {df[z_col].max():.2f}") | |
except Exception as e: | |
st.error(f"Visualization error: {str(e)}") | |
# ================== ๐น PRODUCTION-GRADE ML SECTION ================== | |
elif choice == "Machine Learning": | |
st.header("๐ค Enterprise ML Studio") | |
if st.session_state.cleaned_df is not None: | |
df = st.session_state.cleaned_df | |
# Model Factory | |
st.subheader("๐ญ Model Orchestration") | |
tabs = st.tabs(["AutoML", "Custom Training", "Model Registry"]) | |
with tabs[0]: | |
if st.button("Launch Hyperparameter Optimization", type="primary"): | |
with st.spinner("โก Training 25 model variants..."): | |
try: | |
target = st.selectbox("Target Variable", df.columns) | |
setup(df, target=target, session_id=42, | |
feature_interaction=True, | |
polynomial_features=True) | |
best_model = compare_models(n_select=3) | |
# Visual Leaderboard | |
results = pull() | |
fig = px.bar(results, x='Model', y=['Accuracy', 'AUC'], | |
barmode='group', template="plotly_dark", | |
title="Model Performance Leaderboard") | |
st.plotly_chart(fig, use_container_width=True) | |
except Exception as e: | |
st.error(f"AutoML failed: {str(e)}") | |
# ================== ๐น PREDICTIONS PAGE COMPLETION ================== | |
elif choice == "Predictions": | |
st.title("๐ฎ Make Predictions on New Data") | |
if st.session_state.get("model"): | |
uploaded_file = st.file_uploader("Upload New Data for Prediction", type=["csv", "xlsx"]) | |
if uploaded_file: | |
new_data = pd.read_csv(uploaded_file) if uploaded_file.name.endswith('.csv') else pd.read_excel(uploaded_file) | |
st.write("๐ Preview of New Data:") | |
st.dataframe(new_data.head()) | |
try: | |
predictions = st.session_state.model.predict(new_data) | |
proba = st.session_state.model.predict_proba(new_data) if hasattr(st.session_state.model, 'predict_proba') else None | |
st.subheader("๐ข Predictions:") | |
result_df = pd.DataFrame({ | |
'Prediction': predictions, | |
'Confidence': proba.max(axis=1) if proba is not None else [1.0]*len(predictions) | |
}) | |
st.dataframe(result_df.style.background_gradient(cmap='Blues')) | |
# Download predictions | |
csv = result_df.to_csv(index=False).encode('utf-8') | |
st.download_button( | |
label="๐ฅ Download Predictions", | |
data=csv, | |
file_name='predictions.csv', | |
mime='text/csv' | |
) | |
except Exception as e: | |
st.error(f"Prediction error: {str(e)}") | |
else: | |
st.warning("โ ๏ธ No trained model found. Please train a model first.") | |
# ================== ๐น VISUALIZATION PAGE COMPLETION ================== | |
# ================== ๐น VISUALIZATION PAGE COMPLETION ================== | |
elif choice == "Visualization": | |
st.header("๐ Advanced Visualization Lab") | |
if st.session_state.cleaned_df is not None: | |
df = st.session_state.cleaned_df | |
# Smart Visualization Assistant | |
col1, col2 = st.columns([1, 3]) | |
with col1: | |
if st.button("โจ Suggest Visualizations", help="Generate smart visualization recommendations"): | |
with st.spinner("๐จ Generating recommendations..."): | |
try: | |
numeric_cols = df.select_dtypes(include=np.number).columns.tolist() | |
cat_cols = df.select_dtypes(include=['object', 'category']).columns.tolist() | |
# Auto-detect visualization types | |
if len(numeric_cols) >= 3: | |
st.session_state.viz_type = "3D Scatter" | |
elif len(cat_cols) > 0: | |
st.session_state.viz_type = "Pie" | |
else: | |
st.session_state.viz_type = "Histogram" | |
st.success(f"Recommended visualization type: {st.session_state.viz_type}") | |
except Exception as e: | |
st.error(f"Recommendation failed: {str(e)}") | |
# Manual Visualization Controls | |
with st.expander("๐จ Custom Visualization", expanded=True): | |
plot_options = ["3D Scatter", "Line", "Bar", "Pie", "Histogram", "Box", "Violin", "Heatmap"] | |
plot_type = st.selectbox("Select Plot Type", plot_options, | |
index=plot_options.index(st.session_state.viz_type) if 'viz_type' in st.session_state else 0) | |
# Dynamic Axis Selection | |
col1, col2, col3 = st.columns(3) | |
fig = None | |
# 3D Scatter Plot | |
if plot_type == "3D Scatter": | |
with col1: | |
x_axis = st.selectbox("X Axis", df.columns, index=0) | |
with col2: | |
y_axis = st.selectbox("Y Axis", df.columns, index=min(1, len(df.columns)-1)) | |
with col3: | |
z_axis = st.selectbox("Z Axis", df.columns, index=min(2, len(df.columns)-1)) | |
color_by = st.selectbox("Color By", [None] + df.columns.tolist()) | |
fig = px.scatter_3d(df, x=x_axis, y=y_axis, z=z_axis, color=color_by, | |
color_continuous_scale=px.colors.cyclical.IceFire) | |
# Line Chart | |
elif plot_type == "Line": | |
with col1: | |
x_axis = st.selectbox("X Axis", df.columns, index=0) | |
with col2: | |
y_axis = st.selectbox("Y Axis", df.select_dtypes(include=np.number).columns.tolist()) | |
with col3: | |
color_by = st.selectbox("Group By", [None] + df.columns.tolist()) | |
fig = px.line(df, x=x_axis, y=y_axis, color=color_by, | |
line_group=color_by if color_by else None) | |
# Bar Chart | |
elif plot_type == "Bar": | |
with col1: | |
x_axis = st.selectbox("X Axis", df.columns, index=0) | |
with col2: | |
y_axis = st.selectbox("Y Axis", df.select_dtypes(include=np.number).columns.tolist()) | |
with col3: | |
color_by = st.selectbox("Color By", [None] + df.columns.tolist()) | |
fig = px.bar(df, x=x_axis, y=y_axis, color=color_by, barmode='group') | |
# Pie Chart | |
elif plot_type == "Pie": | |
with col1: | |
names = st.selectbox("Categories", df.select_dtypes(include=['object', 'category']).columns.tolist()) | |
with col2: | |
values = st.selectbox("Values", df.select_dtypes(include=np.number).columns.tolist()) | |
fig = px.pie(df, names=names, values=values, hole=0.3) | |
# Histogram | |
elif plot_type == "Histogram": | |
with col1: | |
num_col = st.selectbox("Numerical Column", df.select_dtypes(include=np.number).columns.tolist()) | |
with col2: | |
color_by = st.selectbox("Split By", [None] + df.columns.tolist()) | |
fig = px.histogram(df, x=num_col, color=color_by, marginal="rug", | |
nbins=st.slider("Number of Bins", 5, 100, 20)) | |
# Box Plot | |
elif plot_type == "Box": | |
with col1: | |
y_axis = st.selectbox("Y Axis", df.select_dtypes(include=np.number).columns.tolist()) | |
with col2: | |
x_axis = st.selectbox("X Axis (Optional)", [None] + df.columns.tolist()) | |
fig = px.box(df, x=x_axis, y=y_axis, color=x_axis) | |
# Violin Plot | |
elif plot_type == "Violin": | |
with col1: | |
y_axis = st.selectbox("Y Axis", df.select_dtypes(include=np.number).columns.tolist()) | |
with col2: | |
x_axis = st.selectbox("X Axis (Optional)", [None] + df.columns.tolist()) | |
fig = px.violin(df, x=x_axis, y=y_axis, color=x_axis, box=True) | |
# Heatmap | |
elif plot_type == "Heatmap": | |
numeric_cols = df.select_dtypes(include=np.number).columns.tolist() | |
selected_cols = st.multiselect("Select Numerical Columns", numeric_cols, default=numeric_cols[:5]) | |
if len(selected_cols) >= 2: | |
corr_matrix = df[selected_cols].corr() | |
fig = px.imshow(corr_matrix, text_auto=True, | |
color_continuous_scale=px.colors.diverging.RdBu_r) | |
else: | |
st.warning("Select at least 2 numerical columns for heatmap") | |
# Plot Customization | |
if fig: | |
with st.expander("๐ญ Style Customization"): | |
col1, col2 = st.columns(2) | |
with col1: | |
color_theme = st.selectbox("Color Theme", px.colors.named_colorscales(), | |
index=px.colors.named_colorscales().index('Viridis')) | |
fig.update_layout(colorway=px.colors.sequential[color_theme]) | |
with col2: | |
fig.update_layout( | |
template=st.selectbox("Theme Style", ["plotly", "plotly_dark", "ggplot2", "seaborn"]), | |
font_size=st.slider("Font Size", 10, 24, 14) | |
) | |
# Display Plot | |
st.plotly_chart(fig, use_container_width=True) | |
# Download Button | |
plot_html = fig.to_html() | |
st.download_button( | |
label="๐ฅ Download Plot", | |
data=plot_html, | |
file_name=f"{plot_type.replace(' ', '_')}_plot.html", | |
mime="text/html" | |
) |