Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -3,54 +3,27 @@ import pandas as pd
|
|
3 |
import numpy as np
|
4 |
import plotly.express as px
|
5 |
import plotly.graph_objects as go
|
6 |
-
import matplotlib.pyplot as plt
|
7 |
from scipy.stats import pearsonr, spearmanr
|
8 |
from sklearn.inspection import permutation_importance
|
9 |
from sklearn.preprocessing import StandardScaler, LabelEncoder
|
10 |
-
from sklearn.model_selection import train_test_split
|
11 |
-
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
|
12 |
-
from sklearn.
|
13 |
-
from
|
14 |
-
from
|
15 |
-
import joblib
|
|
|
16 |
import shap
|
17 |
from datetime import datetime
|
|
|
18 |
|
19 |
-
#
|
20 |
-
|
21 |
-
#
|
22 |
-
|
23 |
-
|
24 |
-
page_icon="🔮",
|
25 |
-
layout="wide",
|
26 |
-
initial_sidebar_state="expanded"
|
27 |
-
)
|
28 |
-
|
29 |
|
30 |
-
# --------------------------
|
31 |
-
# Custom Styling
|
32 |
-
# --------------------------
|
33 |
-
st.markdown("""
|
34 |
-
<style>
|
35 |
-
.main {background-color: #f8f9fa;}
|
36 |
-
.sidebar .sidebar-content {background-color: #2c3e50;}
|
37 |
-
.stButton>button {background-color: #3498db; color: white;}
|
38 |
-
.stTextInput>div>div>input {border: 1px solid #3498db;}
|
39 |
-
.stSelectbox>div>div>select {border: 1px solid #3498db;}
|
40 |
-
.stSlider>div>div>div>div {background-color: #3498db;}
|
41 |
-
.metric {padding: 15px; background-color: white; border-radius: 10px; box-shadow: 0 2px 5px rgba(0,0,0,0.1);}
|
42 |
-
</style>
|
43 |
-
""", unsafe_allow_html=True)
|
44 |
-
|
45 |
-
# --------------------------
|
46 |
-
# Session State Initialization
|
47 |
-
# --------------------------
|
48 |
-
if 'raw_data' not in st.session_state:
|
49 |
-
st.session_state.raw_data = None
|
50 |
-
if 'cleaned_data' not in st.session_state:
|
51 |
-
st.session_state.cleaned_data = None
|
52 |
-
if 'model' not in st.session_state:
|
53 |
-
st.session_state.model = None
|
54 |
|
55 |
# --------------------------
|
56 |
# Helper Functions
|
@@ -92,7 +65,6 @@ def generate_quality_report(df):
|
|
92 |
report['columns'][col] = col_report
|
93 |
return report
|
94 |
|
95 |
-
# Function to train the model (Separated for clarity and reusability)
|
96 |
def train_model(df, target, features, problem_type, test_size, model_type, model_params, use_grid_search=False):
|
97 |
"""Trains a model with hyperparameter tuning, cross-validation, and customizable model architecture."""
|
98 |
|
@@ -258,13 +230,12 @@ def train_model(df, target, features, problem_type, test_size, model_type, model
|
|
258 |
# Store the column order for prediction purposes
|
259 |
column_order = X.columns
|
260 |
|
261 |
-
return model, scaler, label_encoder, imputer_numerical, metrics, column_order, importance
|
262 |
|
263 |
except Exception as e:
|
264 |
st.error(f"Training failed: {str(e)}")
|
265 |
-
return None, None, None, None, None, None, None
|
266 |
-
|
267 |
-
# Model Validation Function
|
268 |
def validate_model(model_path, df, target, features, test_size):
|
269 |
"""Loads a model, preprocesses data, and evaluates the model on a validation set."""
|
270 |
try:
|
@@ -365,18 +336,18 @@ with st.sidebar:
|
|
365 |
# --------------------------
|
366 |
if app_mode == "Data Upload":
|
367 |
st.title("📤 Data Upload & Profiling")
|
368 |
-
|
369 |
uploaded_file = st.file_uploader("Upload your dataset (CSV/XLSX)", type=["csv", "xlsx"])
|
370 |
-
|
371 |
if uploaded_file:
|
372 |
try:
|
373 |
if uploaded_file.name.endswith('.csv'):
|
374 |
df = pd.read_csv(uploaded_file)
|
375 |
else:
|
376 |
df = pd.read_excel(uploaded_file)
|
377 |
-
|
378 |
st.session_state.raw_data = df
|
379 |
-
|
380 |
col1, col2, col3 = st.columns(3)
|
381 |
with col1:
|
382 |
st.metric("Rows", df.shape[0])
|
@@ -384,15 +355,15 @@ if app_mode == "Data Upload":
|
|
384 |
st.metric("Columns", df.shape[1])
|
385 |
with col3:
|
386 |
st.metric("Missing Values", df.isna().sum().sum())
|
387 |
-
|
388 |
with st.expander("Data Preview", expanded=True):
|
389 |
st.dataframe(df.head(10), use_container_width=True)
|
390 |
-
|
391 |
if st.button("Generate Full Profile Report"):
|
392 |
with st.spinner("Generating comprehensive analysis..."):
|
393 |
pr = ProfileReport(df, explorative=True)
|
394 |
st_profile_report(pr)
|
395 |
-
|
396 |
except Exception as e:
|
397 |
st.error(f"Error loading file: {str(e)}")
|
398 |
|
@@ -406,6 +377,8 @@ elif app_mode == "Data Cleaning":
|
|
406 |
st.warning("Please upload data first")
|
407 |
st.stop()
|
408 |
|
|
|
|
|
409 |
# Initialize session state (only if it's not already there)
|
410 |
if 'data_versions' not in st.session_state:
|
411 |
st.session_state.data_versions = [st.session_state.raw_data.copy()]
|
|
|
3 |
import numpy as np
|
4 |
import plotly.express as px
|
5 |
import plotly.graph_objects as go
|
6 |
+
import matplotlib.pyplot as plt # For SHAP charts
|
7 |
from scipy.stats import pearsonr, spearmanr
|
8 |
from sklearn.inspection import permutation_importance
|
9 |
from sklearn.preprocessing import StandardScaler, LabelEncoder
|
10 |
+
from sklearn.model_selection import train_test_split, GridSearchCV, cross_val_score
|
11 |
+
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor, GradientBoostingClassifier, GradientBoostingRegressor
|
12 |
+
from sklearn.neural_network import MLPClassifier, MLPRegressor
|
13 |
+
from sklearn.metrics import accuracy_score, mean_squared_error, r2_score, confusion_matrix, classification_report
|
14 |
+
from sklearn.impute import SimpleImputer
|
15 |
+
import joblib # For saving and loading models
|
16 |
+
import os # For file directory
|
17 |
import shap
|
18 |
from datetime import datetime
|
19 |
+
from stqdm import stqdm
|
20 |
|
21 |
+
# Constants used (global)
|
22 |
+
PATH_FILES = "/".join(('.', "files"))
|
23 |
+
# Ensure upload location exists; make dir if it didn't create one.
|
24 |
+
if not os.path.isdir("..") / PATH_FILES:
|
25 |
+
os.makedirs("created", 0o777, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
|
28 |
# --------------------------
|
29 |
# Helper Functions
|
|
|
65 |
report['columns'][col] = col_report
|
66 |
return report
|
67 |
|
|
|
68 |
def train_model(df, target, features, problem_type, test_size, model_type, model_params, use_grid_search=False):
|
69 |
"""Trains a model with hyperparameter tuning, cross-validation, and customizable model architecture."""
|
70 |
|
|
|
230 |
# Store the column order for prediction purposes
|
231 |
column_order = X.columns
|
232 |
|
233 |
+
return model, scaler, label_encoder, imputer_numerical, metrics, column_order, importance, X_train, y_train # Return X_train and y_train
|
234 |
|
235 |
except Exception as e:
|
236 |
st.error(f"Training failed: {str(e)}")
|
237 |
+
return None, None, None, None, None, None, None, None, None
|
238 |
+
|
|
|
239 |
def validate_model(model_path, df, target, features, test_size):
|
240 |
"""Loads a model, preprocesses data, and evaluates the model on a validation set."""
|
241 |
try:
|
|
|
336 |
# --------------------------
|
337 |
if app_mode == "Data Upload":
|
338 |
st.title("📤 Data Upload & Profiling")
|
339 |
+
|
340 |
uploaded_file = st.file_uploader("Upload your dataset (CSV/XLSX)", type=["csv", "xlsx"])
|
341 |
+
|
342 |
if uploaded_file:
|
343 |
try:
|
344 |
if uploaded_file.name.endswith('.csv'):
|
345 |
df = pd.read_csv(uploaded_file)
|
346 |
else:
|
347 |
df = pd.read_excel(uploaded_file)
|
348 |
+
|
349 |
st.session_state.raw_data = df
|
350 |
+
|
351 |
col1, col2, col3 = st.columns(3)
|
352 |
with col1:
|
353 |
st.metric("Rows", df.shape[0])
|
|
|
355 |
st.metric("Columns", df.shape[1])
|
356 |
with col3:
|
357 |
st.metric("Missing Values", df.isna().sum().sum())
|
358 |
+
|
359 |
with st.expander("Data Preview", expanded=True):
|
360 |
st.dataframe(df.head(10), use_container_width=True)
|
361 |
+
|
362 |
if st.button("Generate Full Profile Report"):
|
363 |
with st.spinner("Generating comprehensive analysis..."):
|
364 |
pr = ProfileReport(df, explorative=True)
|
365 |
st_profile_report(pr)
|
366 |
+
|
367 |
except Exception as e:
|
368 |
st.error(f"Error loading file: {str(e)}")
|
369 |
|
|
|
377 |
st.warning("Please upload data first")
|
378 |
st.stop()
|
379 |
|
380 |
+
df = st.session_state.raw_data.copy() # Ensure df is defined in this section
|
381 |
+
|
382 |
# Initialize session state (only if it's not already there)
|
383 |
if 'data_versions' not in st.session_state:
|
384 |
st.session_state.data_versions = [st.session_state.raw_data.copy()]
|