CosmickVisions commited on
Commit
a0795a8
·
verified ·
1 Parent(s): 0cf55dc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -53
app.py CHANGED
@@ -3,54 +3,27 @@ import pandas as pd
3
  import numpy as np
4
  import plotly.express as px
5
  import plotly.graph_objects as go
6
- import matplotlib.pyplot as plt #For SHAP charts
7
  from scipy.stats import pearsonr, spearmanr
8
  from sklearn.inspection import permutation_importance
9
  from sklearn.preprocessing import StandardScaler, LabelEncoder
10
- from sklearn.model_selection import train_test_split
11
- from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
12
- from sklearn.metrics import accuracy_score, mean_squared_error
13
- from ydata_profiling import ProfileReport
14
- from streamlit_pandas_profiling import st_profile_report
15
- import joblib
 
16
  import shap
17
  from datetime import datetime
 
18
 
19
- # --------------------------
20
- # Page Configuration
21
- # --------------------------
22
- st.set_page_config(
23
- page_title="DataInsight Pro",
24
- page_icon="🔮",
25
- layout="wide",
26
- initial_sidebar_state="expanded"
27
- )
28
-
29
 
30
- # --------------------------
31
- # Custom Styling
32
- # --------------------------
33
- st.markdown("""
34
- <style>
35
- .main {background-color: #f8f9fa;}
36
- .sidebar .sidebar-content {background-color: #2c3e50;}
37
- .stButton>button {background-color: #3498db; color: white;}
38
- .stTextInput>div>div>input {border: 1px solid #3498db;}
39
- .stSelectbox>div>div>select {border: 1px solid #3498db;}
40
- .stSlider>div>div>div>div {background-color: #3498db;}
41
- .metric {padding: 15px; background-color: white; border-radius: 10px; box-shadow: 0 2px 5px rgba(0,0,0,0.1);}
42
- </style>
43
- """, unsafe_allow_html=True)
44
-
45
- # --------------------------
46
- # Session State Initialization
47
- # --------------------------
48
- if 'raw_data' not in st.session_state:
49
- st.session_state.raw_data = None
50
- if 'cleaned_data' not in st.session_state:
51
- st.session_state.cleaned_data = None
52
- if 'model' not in st.session_state:
53
- st.session_state.model = None
54
 
55
  # --------------------------
56
  # Helper Functions
@@ -92,7 +65,6 @@ def generate_quality_report(df):
92
  report['columns'][col] = col_report
93
  return report
94
 
95
- # Function to train the model (Separated for clarity and reusability)
96
  def train_model(df, target, features, problem_type, test_size, model_type, model_params, use_grid_search=False):
97
  """Trains a model with hyperparameter tuning, cross-validation, and customizable model architecture."""
98
 
@@ -258,13 +230,12 @@ def train_model(df, target, features, problem_type, test_size, model_type, model
258
  # Store the column order for prediction purposes
259
  column_order = X.columns
260
 
261
- return model, scaler, label_encoder, imputer_numerical, metrics, column_order, importance
262
 
263
  except Exception as e:
264
  st.error(f"Training failed: {str(e)}")
265
- return None, None, None, None, None, None, None
266
-
267
- # Model Validation Function
268
  def validate_model(model_path, df, target, features, test_size):
269
  """Loads a model, preprocesses data, and evaluates the model on a validation set."""
270
  try:
@@ -365,18 +336,18 @@ with st.sidebar:
365
  # --------------------------
366
  if app_mode == "Data Upload":
367
  st.title("📤 Data Upload & Profiling")
368
-
369
  uploaded_file = st.file_uploader("Upload your dataset (CSV/XLSX)", type=["csv", "xlsx"])
370
-
371
  if uploaded_file:
372
  try:
373
  if uploaded_file.name.endswith('.csv'):
374
  df = pd.read_csv(uploaded_file)
375
  else:
376
  df = pd.read_excel(uploaded_file)
377
-
378
  st.session_state.raw_data = df
379
-
380
  col1, col2, col3 = st.columns(3)
381
  with col1:
382
  st.metric("Rows", df.shape[0])
@@ -384,15 +355,15 @@ if app_mode == "Data Upload":
384
  st.metric("Columns", df.shape[1])
385
  with col3:
386
  st.metric("Missing Values", df.isna().sum().sum())
387
-
388
  with st.expander("Data Preview", expanded=True):
389
  st.dataframe(df.head(10), use_container_width=True)
390
-
391
  if st.button("Generate Full Profile Report"):
392
  with st.spinner("Generating comprehensive analysis..."):
393
  pr = ProfileReport(df, explorative=True)
394
  st_profile_report(pr)
395
-
396
  except Exception as e:
397
  st.error(f"Error loading file: {str(e)}")
398
 
@@ -406,6 +377,8 @@ elif app_mode == "Data Cleaning":
406
  st.warning("Please upload data first")
407
  st.stop()
408
 
 
 
409
  # Initialize session state (only if it's not already there)
410
  if 'data_versions' not in st.session_state:
411
  st.session_state.data_versions = [st.session_state.raw_data.copy()]
 
3
  import numpy as np
4
  import plotly.express as px
5
  import plotly.graph_objects as go
6
+ import matplotlib.pyplot as plt # For SHAP charts
7
  from scipy.stats import pearsonr, spearmanr
8
  from sklearn.inspection import permutation_importance
9
  from sklearn.preprocessing import StandardScaler, LabelEncoder
10
+ from sklearn.model_selection import train_test_split, GridSearchCV, cross_val_score
11
+ from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor, GradientBoostingClassifier, GradientBoostingRegressor
12
+ from sklearn.neural_network import MLPClassifier, MLPRegressor
13
+ from sklearn.metrics import accuracy_score, mean_squared_error, r2_score, confusion_matrix, classification_report
14
+ from sklearn.impute import SimpleImputer
15
+ import joblib # For saving and loading models
16
+ import os # For file directory
17
  import shap
18
  from datetime import datetime
19
+ from stqdm import stqdm
20
 
21
+ # Constants used (global)
22
+ PATH_FILES = "/".join(('.', "files"))
23
+ # Ensure upload location exists; make dir if it didn't create one.
24
+ if not os.path.isdir("..") / PATH_FILES:
25
+ os.makedirs("created", 0o777, exist_ok=True)
 
 
 
 
 
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  # --------------------------
29
  # Helper Functions
 
65
  report['columns'][col] = col_report
66
  return report
67
 
 
68
  def train_model(df, target, features, problem_type, test_size, model_type, model_params, use_grid_search=False):
69
  """Trains a model with hyperparameter tuning, cross-validation, and customizable model architecture."""
70
 
 
230
  # Store the column order for prediction purposes
231
  column_order = X.columns
232
 
233
+ return model, scaler, label_encoder, imputer_numerical, metrics, column_order, importance, X_train, y_train # Return X_train and y_train
234
 
235
  except Exception as e:
236
  st.error(f"Training failed: {str(e)}")
237
+ return None, None, None, None, None, None, None, None, None
238
+
 
239
  def validate_model(model_path, df, target, features, test_size):
240
  """Loads a model, preprocesses data, and evaluates the model on a validation set."""
241
  try:
 
336
  # --------------------------
337
  if app_mode == "Data Upload":
338
  st.title("📤 Data Upload & Profiling")
339
+
340
  uploaded_file = st.file_uploader("Upload your dataset (CSV/XLSX)", type=["csv", "xlsx"])
341
+
342
  if uploaded_file:
343
  try:
344
  if uploaded_file.name.endswith('.csv'):
345
  df = pd.read_csv(uploaded_file)
346
  else:
347
  df = pd.read_excel(uploaded_file)
348
+
349
  st.session_state.raw_data = df
350
+
351
  col1, col2, col3 = st.columns(3)
352
  with col1:
353
  st.metric("Rows", df.shape[0])
 
355
  st.metric("Columns", df.shape[1])
356
  with col3:
357
  st.metric("Missing Values", df.isna().sum().sum())
358
+
359
  with st.expander("Data Preview", expanded=True):
360
  st.dataframe(df.head(10), use_container_width=True)
361
+
362
  if st.button("Generate Full Profile Report"):
363
  with st.spinner("Generating comprehensive analysis..."):
364
  pr = ProfileReport(df, explorative=True)
365
  st_profile_report(pr)
366
+
367
  except Exception as e:
368
  st.error(f"Error loading file: {str(e)}")
369
 
 
377
  st.warning("Please upload data first")
378
  st.stop()
379
 
380
+ df = st.session_state.raw_data.copy() # Ensure df is defined in this section
381
+
382
  # Initialize session state (only if it's not already there)
383
  if 'data_versions' not in st.session_state:
384
  st.session_state.data_versions = [st.session_state.raw_data.copy()]