# NOTE(review): the three lines here ("Spaces:" / "Sleeping" / "Sleeping") were
# Hugging Face Spaces page chrome captured by the web extraction, not source code.
import pickle
import re

import numpy as np
import pandas as pd
import streamlit as st

# Wildcard import kept: this page relies on several helpers from data_analysis
# (line_plot, line_plot_target, summary, correlation_plot, format_numbers, ...).
from data_analysis import *
from post_gres_cred import db_cred
from utilities import load_local_css, project_selection, set_header, update_db
# Page chrome must be configured before any other Streamlit call.
st.set_page_config(
    page_title="Data Assessment",
    page_icon=":shark:",
    layout="wide",
    initial_sidebar_state="collapsed",
)

schema = db_cred["schema"]
load_local_css("styles.css")
set_header()

# Seed session defaults so later lookups never raise KeyError.
if "username" not in st.session_state:
    st.session_state["username"] = None
if "project_name" not in st.session_state:
    st.session_state["project_name"] = None
if "project_dct" not in st.session_state:
    # No project loaded yet: send the user to project selection and halt.
    project_selection()
    st.stop()
# Guard clause replacing the original page-wide `if username is not None:`
# wrapper — behavior is identical (an anonymous visitor gets a blank page),
# but the rest of the script no longer needs an extra indent level.
if "username" not in st.session_state or st.session_state["username"] is None:
    st.stop()

if st.session_state["project_dct"]["data_import"]["imputed_tool_df"] is None:
    # Fixed: message had a spurious f-string prefix with no placeholders.
    st.error("Please import data from the Data Import Page")
    st.stop()

# Working copies of the imported (already imputed) data and the
# category dict (category name -> list of column names).
st.session_state["cleaned_data"] = st.session_state["project_dct"]["data_import"][
    "imputed_tool_df"
]
st.session_state["category_dict"] = st.session_state["project_dct"]["data_import"][
    "category_dict"
]

cols1 = st.columns([2, 1])
with cols1[0]:
    st.markdown(f"**Welcome {st.session_state['username']}**")
with cols1[1]:
    st.markdown(f"**Current Project: {st.session_state['project_name']}**")
st.title("Data Assessment")

# Collect the "Response Metrics" column list; yields a single-element list
# of lists (flattened just below via list(*...)), or [] when the key is absent.
target_variables = [
    st.session_state["category_dict"][key]
    for key in st.session_state["category_dict"]
    if key == "Response Metrics"
]
def format_display(inp):
    """Prettify a raw column name for display in selectors.

    Title-cases the name, turns underscores into spaces, strips the
    bookkeeping tokens "Media" and "Cnt", and trims surrounding whitespace.
    E.g. ``"media_cnt_tv"`` -> ``"Tv"``.
    """
    pretty = inp.title().replace("_", " ")
    for token in ("Media", "Cnt"):
        pretty = pretty.replace(token, "")
    return pretty.strip()
# Flatten the single nested list built above ([] stays []).
target_variables = list(*target_variables)

target_column = st.selectbox(
    "Select the Target Feature/Dependent Variable (will be used in all charts as reference)",
    target_variables,
    index=st.session_state["project_dct"]["data_validation"]["target_column"],
    format_func=format_display,
)
# Persist the choice as an index so the widget restores on next visit.
st.session_state["project_dct"]["data_validation"]["target_column"] = (
    target_variables.index(target_column)
)
st.session_state["target_column"] = target_column
if "panel" not in st.session_state["cleaned_data"].columns:
    # No panel column: treat the whole dataset as one aggregated panel and
    # lock the panel picker. (A stray debug st.write('True') was removed here.)
    st.session_state["cleaned_data"]["panel"] = ["Aggregated"] * len(
        st.session_state["cleaned_data"]
    )
    disable = True
else:
    disable = False

selected_panels = st.multiselect(
    "Please choose the panels you wish to analyze.If no panels are selected, insights will be derived from the overall data.",
    st.session_state["cleaned_data"]["panel"].unique(),
    default=st.session_state["project_dct"]["data_validation"]["selected_panels"],
    disabled=disable,
)
st.session_state["project_dct"]["data_validation"]["selected_panels"] = selected_panels
# Per-column aggregation rules for collapsing panels to one row per date:
# "Media" columns are summed, everything else is averaged; 'date'/'panel'
# are grouping keys, not metrics.
aggregation_dict = {
    item: "sum" if key == "Media" else "mean"
    for key, value in st.session_state["category_dict"].items()
    for item in value
    if item not in ["date", "panel"]
}
# Keep only rules for columns actually present in the data.
aggregation_dict = {
    col: rule
    for col, rule in aggregation_dict.items()
    if col in st.session_state["cleaned_data"].columns
}
with st.expander("**Target Variable Analysis**"):
    # Filter to the chosen panels (all data when none selected), then
    # aggregate to one row per date. The two original branches duplicated
    # the groupby/agg/reset_index pipeline; only the filter differs.
    if len(selected_panels) > 0:
        panel_df = st.session_state["cleaned_data"][
            st.session_state["cleaned_data"]["panel"].isin(selected_panels)
        ]
    else:
        panel_df = st.session_state["cleaned_data"]
    st.session_state["Cleaned_data_panel"] = (
        panel_df.groupby(by="date").agg(aggregation_dict).reset_index()
    )

    fig = line_plot_target(
        st.session_state["Cleaned_data_panel"],
        target=target_column,
        title=f"{target_column} Over Time",
    )
    st.plotly_chart(fig, use_container_width=True)

    def _category_cols(category):
        """Columns registered under *category* in the category dict ([] when absent)."""
        return list(st.session_state["category_dict"].get(category, []))

    media_channel = _category_cols("Media")
    spends_features = _category_cols("Spends")
    exo_var = _category_cols("Exogenous")
    internal_var = _category_cols("Internal")
    Non_media_variables = exo_var + internal_var

    st.markdown("### Annual Data Summary")
    summary_df = summary(
        st.session_state["Cleaned_data_panel"],
        media_channel + [target_column] + spends_features,
        spends=None,
        Target=True,
    )
    st.dataframe(summary_df.sort_index(axis=1), use_container_width=True)

    if st.checkbox("View Raw Data"):

        # Fixed: the original called st.cache_resource(...) as a bare
        # statement (a no-op) — the '@' restores the intended caching.
        @st.cache_resource(show_spinner=False)
        def raw_df_gen():
            """Return the panel data sorted by date with display-formatted numbers."""
            # Keep 'date' as datetime while sorting; stringify only for display.
            dates = pd.to_datetime(st.session_state["Cleaned_data_panel"]["date"])
            raw_df = pd.concat(
                [
                    dates,
                    st.session_state["Cleaned_data_panel"]
                    .select_dtypes(np.number)
                    .applymap(format_numbers),
                ],
                axis=1,
            )
            sorted_raw_df = raw_df.sort_values(by="date", ascending=True)
            sorted_raw_df["date"] = sorted_raw_df["date"].dt.strftime("%m/%d/%Y")
            return sorted_raw_df

        st.dataframe(raw_df_gen())
# Single-column layout placeholder (kept: st.columns has layout side effects).
col1 = st.columns(1)
if "selected_feature" not in st.session_state:
    st.session_state["selected_feature"] = None
with st.expander("Media Variables Analysis"):
    st.session_state["selected_feature"] = st.selectbox(
        "Select Media", media_channel + spends_features, format_func=format_display
    )

    spends_col = st.columns(2)
    # Auto-match the spends column whose suffix (text after 'cost_'/'spends_')
    # appears in the selected feature name. Fixed: the original indexed
    # re.split(...)[1] unconditionally and raised IndexError for any spends
    # column without that prefix — such columns are now skipped instead.
    spends_feature = []
    for col in spends_features:
        parts = re.split(r"cost_|spends_", col.lower())
        if len(parts) > 1 and parts[1] in st.session_state["selected_feature"]:
            spends_feature.append(col)

    with spends_col[0]:
        if len(spends_feature) == 0:
            st.warning(
                "The selected metric does not include a 'spends' variable in the data. Please verify that the columns are correctly named or select the appropriate columns in the provided selection box."
            )
        else:
            st.write(
                f'Selected "{spends_feature[0]}" as the corresponding spends variable automatically. If this is incorrect, please click the checkbox to change the variable.'
            )
    with spends_col[1]:
        # Manual override (mandatory when auto-match found nothing).
        if len(spends_feature) == 0 or st.checkbox(
            'Select "Spends" variable for CPM and CPC calculation'
        ):
            spends_feature = [st.selectbox("Spends Variable", spends_features)]

    if "validation" not in st.session_state:
        st.session_state["validation"] = st.session_state["project_dct"][
            "data_validation"
        ]["validated_variables"]

    val_variables = [col for col in media_channel if col != "date"]

    if not set(
        st.session_state["project_dct"]["data_validation"]["validated_variables"]
    ).issubset(set(val_variables)):
        # Stored validations no longer match the current media columns: reset.
        st.session_state["validation"] = []
    else:
        fig_row1 = line_plot(
            st.session_state["Cleaned_data_panel"],
            x_col="date",
            y1_cols=[st.session_state["selected_feature"]],
            y2_cols=[target_column],
            title=f'Analysis of {st.session_state["selected_feature"]} and {target_column} Over Time',
        )
        st.plotly_chart(fig_row1, use_container_width=True)

        st.markdown("### Summary")
        st.dataframe(
            summary(
                st.session_state["Cleaned_data_panel"],
                [st.session_state["selected_feature"]],
                spends=spends_feature[0],
            ),
            use_container_width=True,
        )

        cols2 = st.columns(2)
        # Disable the validate controls once every media variable is validated.
        # (renamed from `help`, which shadowed the builtin)
        if len(
            set(st.session_state["validation"]).intersection(val_variables)
        ) == len(val_variables):
            validate_disabled = True
            validate_help = "All media variables are validated"
        else:
            validate_disabled = False
            validate_help = ""
        with cols2[0]:
            if st.button("Validate", disabled=validate_disabled, help=validate_help):
                st.session_state["validation"].append(
                    st.session_state["selected_feature"]
                )
        with cols2[1]:
            if st.checkbox("Validate All", disabled=validate_disabled, help=validate_help):
                st.session_state["validation"].extend(val_variables)
                st.success("All media variables are validated ✅")

        # Recomputed (not reused from above): the widgets just handled may
        # have extended st.session_state["validation"].
        if len(
            set(st.session_state["validation"]).intersection(val_variables)
        ) != len(val_variables):
            validation_data = pd.DataFrame(
                {
                    "Validate": [
                        col in st.session_state["validation"]
                        for col in val_variables
                    ],
                    "Variables": val_variables,
                }
            )
            sorted_validation_df = validation_data.sort_values(
                by="Variables", ascending=True, na_position="first"
            )
            cols3 = st.columns([1, 30])
            with cols3[1]:
                validation_df = st.data_editor(
                    sorted_validation_df,
                    column_config={
                        "Validate": st.column_config.CheckboxColumn(
                            default=False,
                            width=100,
                        ),
                        "Variables": st.column_config.TextColumn(width=1000),
                    },
                    hide_index=True,
                )
                selected_rows = validation_df[validation_df["Validate"] == True][
                    "Variables"
                ]
                st.session_state["validation"].extend(selected_rows)
            st.session_state["project_dct"]["data_validation"][
                "validated_variables"
            ] = st.session_state["validation"]
            not_validated_variables = [
                col
                for col in val_variables
                if col not in st.session_state["validation"]
            ]
            if not_validated_variables:
                not_validated_message = f'The following variables are not validated:\n{" , ".join(not_validated_variables)}'
                st.warning(not_validated_message)
with st.expander("Non-Media Variables Analysis"):
    if len(Non_media_variables) == 0:
        st.warning("Non-Media variables not present")
    else:
        selected_columns_row4 = st.selectbox(
            "Select Channel",
            Non_media_variables,
            format_func=format_display,
            index=st.session_state["project_dct"]["data_validation"][
                "Non_media_variables"
            ],
        )
        # Persist as an index so the widget restores on next visit.
        st.session_state["project_dct"]["data_validation"][
            "Non_media_variables"
        ] = Non_media_variables.index(selected_columns_row4)

        fig_row4 = line_plot(
            st.session_state["Cleaned_data_panel"],
            x_col="date",
            y1_cols=[selected_columns_row4],
            y2_cols=[target_column],
            title=f"Analysis of {selected_columns_row4} and {target_column} Over Time",
        )
        st.plotly_chart(fig_row4, use_container_width=True)

        # .copy() added: the column assignment below would otherwise mutate
        # the session DataFrame through a slice view (SettingWithCopy).
        sum_df = st.session_state["Cleaned_data_panel"][
            ["date", selected_columns_row4, target_column]
        ].copy()
        sum_df["Year"] = pd.to_datetime(
            st.session_state["Cleaned_data_panel"]["date"]
        ).dt.year
        sum_df = sum_df.drop("date", axis=1).groupby("Year").agg("sum")
        sum_df.loc["Grand Total"] = sum_df.sum()
        # Display formatting: pretty-print numbers, blank out zero/NaN cells.
        sum_df = sum_df.applymap(format_numbers)
        sum_df.fillna("-", inplace=True)
        sum_df = sum_df.replace({"0.0": "-", "nan": "-"})
        st.markdown("### Summary")
        st.dataframe(sum_df, use_container_width=True)
with st.expander("Correlation Analysis"):
    options = list(
        st.session_state["Cleaned_data_panel"].select_dtypes(np.number).columns
    )
    st.session_state["project_dct"]["data_import"].setdefault("correlation", [])
    selected_options = st.multiselect(
        "Select Variables for Correlation Plot",
        [var for var in options if var != target_column],
        # Drop stale saved selections: a default value absent from the
        # options list makes st.multiselect raise.
        default=[
            var
            for var in st.session_state["project_dct"]["data_import"]["correlation"]
            if var in options and var != target_column
        ],
    )
    st.session_state["project_dct"]["data_import"]["correlation"] = selected_options
    st.pyplot(
        correlation_plot(
            st.session_state["Cleaned_data_panel"],
            selected_options,
            target_column,
        )
    )
if st.button("Save Changes", use_container_width=True):
    # Persist the whole project dict back to Postgres as a pickle blob.
    update_db(
        prj_id=st.session_state["project_number"],
        page_nam="Data Validation and Insights",
        file_nam="project_dct",
        pkl_obj=pickle.dumps(st.session_state["project_dct"]),
        schema=schema,
    )
    st.success("Changes saved")