# MediaMixOptimization / pages/2_Data_Assessment.py
# samkeet's picture
# Upload 40 files
# 00b00eb verified
import streamlit as st
import pandas as pd
from data_analysis import *
import numpy as np
import pickle
import streamlit as st
from utilities import set_header, load_local_css, update_db, project_selection
from post_gres_cred import db_cred
from utilities import update_db
import re
# --- Page bootstrap ---------------------------------------------------------
# Page-level configuration is issued before any other Streamlit call here.
st.set_page_config(
    page_title="Data Assessment",
    page_icon=":shark:",
    layout="wide",
    initial_sidebar_state="collapsed",
)

# Database schema name; passed to update_db when the project is saved.
schema = db_cred["schema"]

load_local_css("styles.css")
set_header()

# Make sure the session keys read further down always exist.
if "username" not in st.session_state:
    st.session_state["username"] = None

if "project_name" not in st.session_state:
    st.session_state["project_name"] = None

# No project loaded yet: hand over to the project picker and end this run.
if "project_dct" not in st.session_state:
    project_selection()
    st.stop()
# Everything below needs a logged-in user; end this run early otherwise.
# NOTE(review): written as a guard clause on the assumption that the rest of
# the page formed the body of the original `if` — confirm against the
# indented source.
if st.session_state.get("username") is None:
    st.stop()

# The assessment works on the imputed frame produced by the Data Import page.
if st.session_state["project_dct"]["data_import"]["imputed_tool_df"] is None:
    st.error("Please import data from the Data Import Page")
    st.stop()

# Expose the imported frame and its column categories under short session keys.
st.session_state["cleaned_data"] = st.session_state["project_dct"]["data_import"][
    "imputed_tool_df"
]
st.session_state["category_dict"] = st.session_state["project_dct"]["data_import"][
    "category_dict"
]

# Header row: greeting on the left, active project on the right.
cols1 = st.columns([2, 1])
with cols1[0]:
    st.markdown(f"**Welcome {st.session_state['username']}**")
with cols1[1]:
    st.markdown(f"**Current Project: {st.session_state['project_name']}**")

st.title("Data Assessment")
# Collect the "Response Metrics" entry of the category dictionary, wrapped in
# an outer list: the result holds that single entry, or nothing when the
# category is absent.
target_variables = []
if "Response Metrics" in st.session_state["category_dict"]:
    target_variables.append(st.session_state["category_dict"]["Response Metrics"])
def format_display(inp):
    """Turn a raw column name into a compact label for the selection widgets.

    The name is title-cased, underscores become spaces and the boilerplate
    tokens ``Media`` and ``Cnt`` are dropped. Whitespace is then collapsed so
    that removing a token from the middle of a name does not leave doubled
    spaces behind (``"tv_media_spend"`` -> ``"Tv Spend"``).

    Args:
        inp: Column name as a string.

    Returns:
        The cleaned-up label with single spaces and no leading or trailing
        whitespace.
    """
    label = inp.title().replace("_", " ").replace("Media", "").replace("Cnt", "")
    # split()/join trims both ends and squeezes inner runs of whitespace.
    return " ".join(label.split())
# `target_variables` is a list wrapping at most one list of column names;
# unpacking it yields the flat list of candidate targets (empty if none).
target_variables = list(*target_variables)

if not target_variables:
    # Without a response metric there is no target to plot against, and the
    # index lookup below would fail on an empty selection.
    st.error(
        "No response metrics are available. Please import data from the Data Import Page"
    )
    st.stop()

# The saved position may be stale if the list of response metrics changed
# since it was stored; fall back to the first entry rather than letting the
# widget reject an out-of-range index.
saved_target_index = st.session_state["project_dct"]["data_validation"][
    "target_column"
]
if not isinstance(saved_target_index, int) or not (
    0 <= saved_target_index < len(target_variables)
):
    saved_target_index = 0

target_column = st.selectbox(
    "Select the Target Feature/Dependent Variable (will be used in all charts as reference)",
    target_variables,
    index=saved_target_index,
    format_func=format_display,
)

# Remember the choice by position in the project dictionary and by name in
# the session.
st.session_state["project_dct"]["data_validation"]["target_column"] = (
    target_variables.index(target_column)
)
st.session_state["target_column"] = target_column
if "panel" not in st.session_state["cleaned_data"].columns:
    # Data without a panel column is treated as one synthetic panel, and the
    # selector is disabled because there is nothing to choose between.
    # NOTE(review): "cleaned_data" is the same object as
    # project_dct["data_import"]["imputed_tool_df"], so this column is added
    # to the stored frame as well; presumably later reruns then take the
    # `else` branch and enable the selector — confirm this is intended.
    st.session_state["cleaned_data"]["panel"] = ["Aggregated"] * len(
        st.session_state["cleaned_data"]
    )
    disable = True
else:
    disable = False

panel_options = list(st.session_state["cleaned_data"]["panel"].unique())

# Only pre-select saved panels that still exist in the data; the widget
# rejects defaults that are not among its options.
saved_panels = (
    st.session_state["project_dct"]["data_validation"]["selected_panels"] or []
)
selected_panels = st.multiselect(
    "Please choose the panels you wish to analyze.If no panels are selected, insights will be derived from the overall data.",
    panel_options,
    default=[panel for panel in saved_panels if panel in panel_options],
    disabled=disable,
)

st.session_state["project_dct"]["data_validation"][
    "selected_panels"
] = selected_panels
# Aggregation rule per column: columns listed under the "Media" category are
# summed, columns of every other category are averaged. "date" and "panel"
# are never aggregated, and columns missing from the cleaned data are left
# out.
# NOTE(review): this means "Spends" and response columns are averaged —
# confirm that is the intended treatment.
aggregation_dict = {
    column: ("sum" if category == "Media" else "mean")
    for category, columns in st.session_state["category_dict"].items()
    for column in columns
    if column not in ["date", "panel"]
    and column in st.session_state["cleaned_data"].columns
}
with st.expander("**Target Variable Analysis**"):
    # Keep only the chosen panels (all rows when none is chosen), then
    # collapse to one row per date using the per-column aggregation rules.
    panel_rows = st.session_state["cleaned_data"]
    if selected_panels:
        panel_rows = panel_rows[panel_rows["panel"].isin(selected_panels)]
    st.session_state["Cleaned_data_panel"] = (
        panel_rows.groupby(by="date").agg(aggregation_dict).reset_index()
    )

    fig = line_plot_target(
        st.session_state["Cleaned_data_panel"],
        target=target_column,
        title=f"{target_column} Over Time",
    )
    st.plotly_chart(fig, use_container_width=True)

    # Column groups by category; a category missing from the dictionary
    # gives an empty list.
    media_channel = list(st.session_state["category_dict"].get("Media", []))
    spends_features = list(st.session_state["category_dict"].get("Spends", []))
    exo_var = list(st.session_state["category_dict"].get("Exogenous", []))
    internal_var = list(st.session_state["category_dict"].get("Internal", []))
    Non_media_variables = exo_var + internal_var

    st.markdown("### Annual Data Summary")

    summary_df = summary(
        st.session_state["Cleaned_data_panel"],
        media_channel + [target_column] + spends_features,
        spends=None,
        Target=True,
    )
    st.dataframe(
        summary_df.sort_index(axis=1),
        use_container_width=True,
    )

    if st.checkbox("View Raw Data"):

        def raw_df_gen():
            """Build the display version of the aggregated panel data.

            Numeric columns are passed through ``format_numbers`` and the
            rows are sorted by date while the dates are still datetimes;
            the dates are rendered as ``MM/DD/YYYY`` strings only afterwards.

            Deliberately not cached: the function reads the current session
            state, so a cached copy would go stale when the panel selection
            changes.

            Returns:
                A new DataFrame with the "date" column followed by the
                formatted numeric columns.
            """
            panel_df = st.session_state["Cleaned_data_panel"]
            dates = pd.to_datetime(panel_df["date"])
            # Column-wise Series.map is the element-wise equivalent of the
            # deprecated DataFrame.applymap.
            formatted = panel_df.select_dtypes(np.number).apply(
                lambda column: column.map(format_numbers)
            )
            raw_df = pd.concat([dates, formatted], axis=1)
            sorted_raw_df = raw_df.sort_values(by="date", ascending=True)
            sorted_raw_df["date"] = sorted_raw_df["date"].dt.strftime("%m/%d/%Y")
            return sorted_raw_df

        st.dataframe(raw_df_gen())

# Unused single-column container, kept so the page layout is unchanged.
col1 = st.columns(1)

if "selected_feature" not in st.session_state:
    st.session_state["selected_feature"] = None
def _spends_suffix(column_name):
    """Return the lower-cased text after the first "cost_"/"spends_" marker.

    An empty string is returned when the name contains neither marker, so
    callers can treat "no marker" and "nothing after the marker" alike.
    """
    parts = re.split(r"cost_|spends_", column_name.lower())
    return parts[1] if len(parts) > 1 else ""


with st.expander("Media Variables Analysis"):
    st.session_state["selected_feature"] = st.selectbox(
        "Select Media", media_channel + spends_features, format_func=format_display
    )
    selected_feature = st.session_state["selected_feature"]

    if selected_feature is None:
        # Nothing could be selected (presumably no media or spends columns),
        # so there is nothing to plot or validate.
        st.warning("Media variables not present")
    else:
        # --- Pick the spends column that belongs to the selected feature ---
        # Heuristic: a spends column matches when the text after its
        # "cost_"/"spends_" marker occurs in the selected feature's name
        # (compared case-insensitively).
        spends_col = st.columns(2)
        feature_key = selected_feature.lower()
        spends_feature = [
            col
            for col in spends_features
            if _spends_suffix(col) and _spends_suffix(col) in feature_key
        ]

        with spends_col[0]:
            if not spends_feature:
                st.warning(
                    "The selected metric does not include a 'spends' variable in the data. Please verify that the columns are correctly named or select the appropriate columns in the provided selection box."
                )
            else:
                st.write(
                    f'Selected "{spends_feature[0]}" as the corresponding spends variable automatically. If this is incorrect, please click the checkbox to change the variable.'
                )

        with spends_col[1]:
            # The override checkbox is only offered when a match was found;
            # without a match the manual selector is shown directly.
            if not spends_feature or st.checkbox(
                'Select "Spends" variable for CPM and CPC calculation'
            ):
                spends_feature = [st.selectbox("Spends Variable", spends_features)]

        # --- Validation state ---------------------------------------------
        val_variables = [col for col in media_channel if col != "date"]
        saved_validation = (
            st.session_state["project_dct"]["data_validation"]["validated_variables"]
            or []
        )
        if "validation" not in st.session_state:
            st.session_state["validation"] = list(saved_validation)

        # Saved validations that mention columns which are no longer media
        # variables are stale: start over.
        if not set(saved_validation).issubset(val_variables):
            st.session_state["validation"] = []

        # --- Plot and summary for the selected feature ----------------------
        fig_row1 = line_plot(
            st.session_state["Cleaned_data_panel"],
            x_col="date",
            y1_cols=[selected_feature],
            y2_cols=[target_column],
            title=f"Analysis of {selected_feature} and {target_column} Over Time",
        )
        st.plotly_chart(fig_row1, use_container_width=True)

        st.markdown("### Summary")
        st.dataframe(
            summary(
                st.session_state["Cleaned_data_panel"],
                [selected_feature],
                spends=spends_feature[0],
            ),
            use_container_width=True,
        )

        # --- Validation controls --------------------------------------------
        cols2 = st.columns(2)
        all_validated = set(val_variables).issubset(st.session_state["validation"])
        disable = all_validated
        help_text = "All media variables are validated" if all_validated else ""

        with cols2[0]:
            if st.button("Validate", disabled=disable, help=help_text):
                st.session_state["validation"].append(selected_feature)

        with cols2[1]:
            if st.checkbox("Validate All", disabled=disable, help=help_text):
                st.session_state["validation"].extend(val_variables)
                st.success("All media variables are validated ✅")

        if not set(val_variables).issubset(st.session_state["validation"]):
            validation_data = pd.DataFrame(
                {
                    "Validate": [
                        col in st.session_state["validation"] for col in val_variables
                    ],
                    "Variables": val_variables,
                }
            )
            sorted_validation_df = validation_data.sort_values(
                by="Variables", ascending=True, na_position="first"
            )

            cols3 = st.columns([1, 30])
            with cols3[1]:
                validation_df = st.data_editor(
                    sorted_validation_df,
                    column_config={
                        "Validate": st.column_config.CheckboxColumn(
                            default=False,
                            width=100,
                        ),
                        "Variables": st.column_config.TextColumn(width=1000),
                    },
                    hide_index=True,
                )

            # Rows ticked in the editor count as validated.
            selected_rows = validation_df.loc[
                validation_df["Validate"].eq(True), "Variables"
            ]
            st.session_state["validation"].extend(selected_rows)

        # The script reruns on every interaction, so the appends above would
        # pile up duplicates; keep each name once, in first-seen order. Names
        # that are not media variables (e.g. a spends column validated via
        # the button) are dropped, because they would make the saved list
        # fail the subset check above and wipe every validation next run.
        st.session_state["validation"] = [
            name
            for name in dict.fromkeys(st.session_state["validation"])
            if name in val_variables
        ]
        st.session_state["project_dct"]["data_validation"][
            "validated_variables"
        ] = st.session_state["validation"]

        not_validated_variables = [
            col for col in val_variables if col not in st.session_state["validation"]
        ]
        if not_validated_variables:
            not_validated_message = f'The following variables are not validated:\n{" , ".join(not_validated_variables)}'
            st.warning(not_validated_message)
with st.expander("Non-Media Variables Analysis"):
    if len(Non_media_variables) == 0:
        st.warning("Non-Media variables not present")
    else:
        # The saved position may be stale if the variable list changed since
        # it was stored; fall back to the first entry rather than letting the
        # widget reject an out-of-range index.
        saved_non_media_index = st.session_state["project_dct"]["data_validation"][
            "Non_media_variables"
        ]
        if not isinstance(saved_non_media_index, int) or not (
            0 <= saved_non_media_index < len(Non_media_variables)
        ):
            saved_non_media_index = 0

        selected_columns_row4 = st.selectbox(
            "Select Channel",
            Non_media_variables,
            format_func=format_display,
            index=saved_non_media_index,
        )
        st.session_state["project_dct"]["data_validation"][
            "Non_media_variables"
        ] = Non_media_variables.index(selected_columns_row4)

        # Dual-axis line plot: selected variable against the target.
        fig_row4 = line_plot(
            st.session_state["Cleaned_data_panel"],
            x_col="date",
            y1_cols=[selected_columns_row4],
            y2_cols=[target_column],
            title=f"Analysis of {selected_columns_row4} and {target_column} Over Time",
        )
        st.plotly_chart(fig_row4, use_container_width=True)

        selected_non_media = selected_columns_row4

        # Work on a copy: adding "Year" to a bare column selection would be a
        # chained assignment on a slice of the session frame.
        sum_df = st.session_state["Cleaned_data_panel"][
            ["date", selected_non_media, target_column]
        ].copy()
        sum_df["Year"] = pd.to_datetime(sum_df["date"]).dt.year

        # Yearly totals plus a grand-total row, formatted for display.
        sum_df = sum_df.drop("date", axis=1).groupby("Year").agg("sum")
        sum_df.loc["Grand Total"] = sum_df.sum()
        # Column-wise Series.map is the element-wise equivalent of the
        # deprecated DataFrame.applymap.
        sum_df = sum_df.apply(lambda column: column.map(format_numbers))
        sum_df = sum_df.fillna("-")
        sum_df = sum_df.replace({"0.0": "-", "nan": "-"})

        st.markdown("### Summary")
        st.dataframe(sum_df, use_container_width=True)
with st.expander("Correlation Analysis"):
    # Every numeric column except the target can be correlated against it.
    options = list(
        st.session_state["Cleaned_data_panel"].select_dtypes(np.number).columns
    )
    correlation_choices = [var for var in options if var != target_column]

    if "correlation" not in st.session_state["project_dct"]["data_import"]:
        st.session_state["project_dct"]["data_import"]["correlation"] = []

    # Only pre-select saved variables that are still valid choices (e.g. the
    # target may have changed since they were stored); the widget rejects
    # defaults that are not among its options.
    saved_correlation = (
        st.session_state["project_dct"]["data_import"]["correlation"] or []
    )
    selected_options = st.multiselect(
        "Select Variables for Correlation Plot",
        correlation_choices,
        default=[var for var in saved_correlation if var in correlation_choices],
    )
    st.session_state["project_dct"]["data_import"]["correlation"] = selected_options

    st.pyplot(
        correlation_plot(
            st.session_state["Cleaned_data_panel"],
            selected_options,
            target_column,
        )
    )
# On request, pickle the project dictionary and hand it to update_db for
# this project, then confirm to the user.
if st.button("Save Changes", use_container_width=True):
    current_project_id = st.session_state["project_number"]
    pickled_project = pickle.dumps(st.session_state["project_dct"])
    update_db(
        prj_id=current_project_id,
        page_nam="Data Validation and Insights",
        file_nam="project_dct",
        pkl_obj=pickled_project,
        schema=schema,
    )
    st.success("Changes saved")