import streamlit as st
from numerize.numerize import numerize
import numpy as np
from functools import partial
from collections import OrderedDict
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from utilities import format_numbers,load_local_css,set_header,initialize_data,load_authenticator,send_email,channel_name_formating
from classes import class_from_dict,class_to_dict
import pickle
import streamlit_authenticator as stauth
import yaml
from yaml import SafeLoader
import re
import pandas as pd
import plotly.express as px
target = 'Revenue'

st.set_page_config(layout='wide')
load_local_css('styles.css')
set_header()

# Re-assign every session-state entry to itself, except the 'logout', 'login'
# and 'config' keys and any 'FormSubmitter*' key.
# NOTE(review): presumably this keeps widget-bound values alive across page
# switches and leaves keys owned by streamlit_authenticator alone — confirm.
for key, value in st.session_state.items():
    if key in ('logout', 'login', 'config') or key.startswith('FormSubmitter'):
        continue
    st.session_state[key] = value
# ======================================================== #
# ======================= Functions ====================== #
# ======================================================== #
def optimize():
    """
    Optimize the spends of the channels selected for optimization.

    Reads the selected channels from ``st.session_state['optimization_channels']``
    and the overall spend change from ``st.session_state['total_spends_change']``,
    then calls ``scenario.optimize`` with them. For every
    ``(channel_name, modified_spends)`` pair it returns, the channel's session
    value is set to the numerized ``modified_spends * conversion_rate``. The
    ``<channel_name>_change`` entry is set to the percentage change versus the
    channel's actual total spends.

    Does nothing when no channel is selected.
    """
    channel_list = [
        name
        for name, selected in st.session_state['optimization_channels'].items()
        if selected
    ]
    if not channel_list:
        return

    scenario = st.session_state['scenario']
    result = scenario.optimize(st.session_state['total_spends_change'], channel_list)
    for channel_name, modified_spends in result:
        channel = scenario.channels[channel_name]
        st.session_state[channel_name] = numerize(modified_spends * channel.conversion_rate, 1)
        prev_spends = channel.actual_total_spends
        if prev_spends:
            st.session_state[f'{channel_name}_change'] = round(
                100 * (modified_spends - prev_spends) / prev_spends, 2
            )
        else:
            # A channel with no actual spends has no meaningful percentage
            # change; report 0 instead of raising ZeroDivisionError.
            st.session_state[f'{channel_name}_change'] = 0
def save_scenario(scenario_name):
    """
    Save the current scenario under the given name and persist all saved scenarios.

    The scenario is serialised with ``class_to_dict`` into
    ``st.session_state['saved_scenarios']``. That mapping is created as an
    ``OrderedDict`` on first use. The ``scenario_input`` session entry is
    then cleared, and the whole ``saved_scenarios`` mapping is pickled to
    ``../saved_scenarios.pkl``.

    Parameters
    ----------
    scenario_name
        Name of the scenario to be saved; an existing entry with the same
        name is overwritten.
    """
    if 'saved_scenarios' not in st.session_state:
        # Initialise only this entry -- rebinding ``st.session_state`` itself
        # would throw away the whole session state.
        st.session_state['saved_scenarios'] = OrderedDict()
    st.session_state['saved_scenarios'][scenario_name] = class_to_dict(st.session_state['scenario'])
    st.session_state['scenario_input'] = ""
    # NOTE(review): the path is relative to the process working directory, and
    # whoever unpickles this file later must trust its contents.
    with open('../saved_scenarios.pkl', 'wb') as f:
        pickle.dump(st.session_state['saved_scenarios'], f)
def update_all_spends():
    """
    Apply the overall spends change (a percentage) uniformly to every channel.

    For each channel in ``st.session_state['channels_list']``, the actual total
    spends are scaled by ``1 + total_spends_change / 100`` and pushed into the
    scenario. The channel's session value becomes the numerized
    ``scaled_spends * conversion_rate``, and ``<channel>_change`` is set to
    the overall percentage.
    """
    pct = st.session_state['total_spends_change']
    scale = 1 + pct / 100
    scenario = st.session_state['scenario']
    for name in st.session_state['channels_list']:
        ch = scenario.channels[name]
        new_spends = scale * ch.actual_total_spends
        scenario.update(name, new_spends)
        st.session_state[name] = numerize(new_spends * ch.conversion_rate, 1)
        st.session_state[f'{name}_change'] = pct
def extract_number_for_string(string_input):
    """
    Convert a compact number string such as ``"1.5K"``, ``"20M"`` or ``"3B"`` to a float.

    The K/M/B suffix is case-insensitive, and surrounding whitespace is
    ignored. Strings without a suffix are parsed as plain numbers.

    Parameters
    ----------
    string_input : str
        Text to convert.

    Returns
    -------
    float
        The numeric value (K = 10**3, M = 10**6, B = 10**9).

    Raises
    ------
    ValueError
        If the numeric part cannot be parsed by ``float``.
    """
    multipliers = {'K': 10**3, 'M': 10**6, 'B': 10**9}
    text = string_input.strip().upper()
    if text and text[-1] in multipliers:
        return float(text[:-1]) * multipliers[text[-1]]
    # No suffix: previously this fell through and returned None.
    return float(text)
def validate_input(string_input):
    """
    Check whether the input is a compact spend value like ``"1.5K"``, ``"20M"`` or ``"3B"``.

    The whole string must be digits with an optional decimal part, followed
    by exactly one K/M/B suffix. The suffix is case-insensitive because
    ``extract_number_for_string`` upper-cases its input. Plain numbers
    without a suffix are rejected.

    Returns
    -------
    bool
        True when the string is in the accepted format.
    """
    # Use [KMB], not [K|M|B]: inside a character class '|' is a literal.
    # fullmatch (rather than match + '$') also rejects a trailing newline.
    return re.fullmatch(r'\d+\.?\d*[KMB]', string_input, flags=re.IGNORECASE) is not None
def update_data_by_percent(channel_name):
    """
    Recompute a channel's spends from its percentage-change entry.

    The baseline is ``actual_total_spends * conversion_rate``. It is scaled by
    ``1 + st.session_state['<channel_name>_change'] / 100`` and stored
    (numerized) as the channel's session value. The scenario is updated with
    the scaled value divided back by the conversion rate.
    """
    scenario = st.session_state['scenario']
    channel = scenario.channels[channel_name]
    baseline = channel.actual_total_spends * channel.conversion_rate
    factor = 1 + st.session_state[f'{channel_name}_change'] / 100
    new_value = baseline * factor
    st.session_state[channel_name] = numerize(new_value, 1)
    scenario.update(channel_name, new_value / channel.conversion_rate)
def update_data(channel_name):
    """
    Apply a spends value typed for the channel (e.g. ``"1.5M"``).

    When ``st.session_state[channel_name]`` passes ``validate_input``, the
    value is parsed and the channel's ``<channel_name>_change`` percentage is
    recomputed. The percentage is relative to
    ``actual_total_spends * conversion_rate``. The scenario is then updated
    with the parsed value divided by the conversion rate.

    Input that fails validation is ignored.
    """
    if not validate_input(st.session_state[channel_name]):
        return
    scenario = st.session_state['scenario']
    channel = scenario.channels[channel_name]
    modified_spends = extract_number_for_string(st.session_state[channel_name])
    prev_spends = channel.actual_total_spends * channel.conversion_rate
    if prev_spends:
        st.session_state[f'{channel_name}_change'] = round(
            100 * (modified_spends - prev_spends) / prev_spends, 2
        )
    else:
        # No baseline spends: the percentage change is undefined, so report 0
        # instead of raising ZeroDivisionError.
        st.session_state[f'{channel_name}_change'] = 0
    scenario.update(channel_name, modified_spends / channel.conversion_rate)
def select_channel_for_optimization(channel_name):
    """
    Copy the ``<channel_name>_selected`` session value into the channel's
    flag in ``st.session_state['optimization_channels']``.
    """
    is_selected = st.session_state[f'{channel_name}_selected']
    flags = st.session_state['optimization_channels']
    flags[channel_name] = is_selected
def select_all_channels_for_optimization():
    """
    Set every channel's optimization flag to the ``optimze_all_channels`` value.

    The value is written both to each channel's ``<channel>_selected``
    session entry and to its flag in ``st.session_state['optimization_channels']``.
    """
    select_all = st.session_state['optimze_all_channels']
    flags = st.session_state['optimization_channels']
    for name in flags:
        st.session_state[f'{name}_selected'] = select_all
        flags[name] = select_all
def update_penalty():
    """
    Pass the ``apply_penalty`` session value to the current scenario's
    ``update_penalty`` method.
    """
    penalty_flag = st.session_state['apply_penalty']
    st.session_state['scenario'].update_penalty(penalty_flag)
def reset_scenario():
    """
    Re-initialise the data and clear the per-channel optimization inputs.

    Calls ``initialize_data()``, presumably restoring the default scenario.
    Then, for every channel in ``st.session_state['channels_list']``, it sets
    ``<channel>_selected`` to False and ``<channel>_change`` to 0. Finally,
    ``optimze_all_channels`` is set to False.
    """
    # NOTE(review): st.session_state['optimization_channels'] is not touched
    # here; confirm initialize_data() resets it, otherwise stale selections
    # would still be used by optimize().
    initialize_data()
    for name in st.session_state['channels_list']:
        st.session_state[f'{name}_selected'] = False
        st.session_state[f'{name}_change'] = 0
    st.session_state['optimze_all_channels'] = False
def format_number(num):
    """
    Format a number compactly: millions as ``"1.23M"``, thousands as ``"12K"``,
    and anything smaller with two decimals.

    Thresholds are applied to the absolute value, so negative numbers get the
    same suffixes. A value whose thousands would round up to 1000 (e.g. 999_500)
    is shown as millions rather than ``"1000K"``.

    Parameters
    ----------
    num : int or float
        Value to format.

    Returns
    -------
    str
    """
    magnitude = abs(num)
    if magnitude >= 1_000_000 or round(magnitude / 1_000) >= 1_000:
        return f"{num / 1_000_000:.2f}M"
    if magnitude >= 1_000:
        return f"{num / 1_000:.0f}K"
    return f"{num:.2f}"
def summary_plot(data, x, y, title, text_column):
    """
    Build a horizontal bar chart coloured by ``Channel_name``.

    Parameters
    ----------
    data : pandas.DataFrame
        Source data. It must contain the ``x``, ``y``, ``text_column`` and
        ``Channel_name`` columns, and it is not modified.
    x, y : str
        Column names for the bar length and the category axis; ``x`` is also
        used as the x-axis title.
    title : str
        Chart title.
    text_column : str
        Column whose values label each bar. It is coerced to numeric
        (unparseable values become NaN) so the ``.2s`` format applies.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    # Coerce on a copy *before* building the figure: px.bar reads the data when
    # called, so converting afterwards neither affected the chart nor should it
    # mutate the caller's frame.
    data = data.copy()
    data[text_column] = pd.to_numeric(data[text_column], errors='coerce')
    fig = px.bar(data, x=x, y=y, orientation='h',
                 title=title, text=text_column, color='Channel_name')
    # Show bar labels and hover values with SI prefixes, 2 significant digits.
    fig.update_traces(texttemplate='%{text:.2s}', textposition='outside', hovertemplate='%{x:.2s}')
    fig.update_layout(xaxis_title=x, yaxis_title='Channel Name', showlegend=False)
    return fig
def s_curve(x, K, b, a, x0):
    """
    Evaluate the logistic curve ``K / (1 + b * exp(-a * (x - x0)))``.

    Works element-wise on NumPy arrays as well as on scalars.

    Parameters
    ----------
    x : float or numpy.ndarray
        Input value(s).
    K : float
        Curve height as x grows (for a > 0).
    b, a, x0 : float
        Scale, steepness and horizontal shift of the curve.
    """
    exponent = -a * (x - x0)
    denominator = 1 + b * np.exp(exponent)
    return K / denominator
@st.cache
def plot_response_curves():
    """
    Plot every channel's response curve in a 4-column grid of subplots.

    For each channel in the module-level ``channels_list`` this draws:

    * the curve ``52 * s_curve(x / 10**power, K, b, a, x0)``. Here ``x`` spans
      0 to three times the channel's largest value in
      ``st.session_state['actual_df']``, and the x axis is ``52 * x * conversion_rate``.
      Hovering shows spend, sale, ROI and marginal ROI.
    * a black marker at the channel's total spends (column sum), with dashed
      guide lines from both axes to that point.

    ``K``, ``b``, ``a`` and ``x0`` come from ``st.session_state['rcs'][channel]``.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    # NOTE(review): st.cache keys only on the (empty) argument list, yet this
    # function reads st.session_state, so the cached figure can be stale or
    # shared between sessions. st.cache is also deprecated in newer Streamlit
    # releases; consider st.cache_data with explicit arguments.
    cols = 4
    # Keep at least the original 6 rows, but grow the grid for more than
    # 6 * cols channels instead of addressing non-existent subplot rows.
    rows = max(6, (len(channels_list) + cols - 1) // cols)
    rcs = st.session_state['rcs']
    scenario = st.session_state['scenario']
    shapes = []
    fig = make_subplots(rows=rows, cols=cols, subplot_titles=channels_list)
    for i, channel in enumerate(channels_list):
        row_idx, col_idx = 1 + i // cols, i % cols + 1
        x = st.session_state['actual_df'][channel].values
        spends = x.sum()
        # Scale so the largest input has roughly 3 integer digits.
        # NOTE(review): presumably this matches the scaling used when fitting rcs; confirm.
        power = np.ceil(np.log(x.max()) / np.log(10)) - 3
        x = np.linspace(0, 3 * x.max(), 200)
        K = rcs[channel]['K']
        b = rcs[channel]['b']
        a = rcs[channel]['a']
        x0 = rcs[channel]['x0']
        conversion_rate = scenario.channels[channel].conversion_rate
        y = s_curve(x / 10**power, K, b, a, x0)
        # x[0] == 0, so the first ROI is inf/nan; only silence numpy's warning.
        with np.errstate(divide='ignore', invalid='ignore'):
            roi = y / x
        marginal_roi = a * y * (1 - y / K)
        fig.add_trace(
            go.Scatter(
                x=52 * x * conversion_rate,
                y=52 * y,
                name=channel,
                customdata=np.stack((roi, marginal_roi), axis=-1),
                hovertemplate=(
                    "Spend:%{x:$.2s}<br>Sale:%{y:$.2s}"
                    "<br>ROI:%{customdata[0]:.3f}<br>MROI:%{customdata[1]:.3f}"
                ),
            ),
            row=row_idx, col=col_idx,
        )
        # Current operating point: the curve is evaluated at spends / 52
        # (NOTE(review): presumably the average weekly spend; confirm the row count).
        current_x = spends * conversion_rate
        current_y = 52 * s_curve(spends / (10**power * 52), K, b, a, x0)
        fig.add_trace(
            go.Scatter(
                x=[current_x],
                y=[current_y],
                name=channel,
                legendgroup=channel,
                showlegend=False,
                marker=dict(color=['black']),
            ),
            row=row_idx, col=col_idx,
        )
        # Dashed horizontal guide from the y axis to the operating point.
        shapes.append(go.layout.Shape(type="line",
                                      x0=0,
                                      y0=current_y,
                                      x1=current_x,
                                      y1=current_y,
                                      line_width=1,
                                      line_dash="dash",
                                      line_color="black",
                                      xref=f'x{i+1}',
                                      yref=f'y{i+1}'))
        # Dashed vertical guide from the x axis to the operating point.
        shapes.append(go.layout.Shape(type="line",
                                      x0=current_x,
                                      y0=0,
                                      x1=current_x,
                                      y1=current_y,
                                      line_width=1,
                                      line_dash="dash",
                                      line_color="black",
                                      xref=f'x{i+1}',
                                      yref=f'y{i+1}'))
    fig.update_layout(height=1500, width=1000, title_text="Response Curves", showlegend=False, shapes=shapes)
    fig.update_annotations(font_size=10)
    fig.update_xaxes(title='Spends')
    fig.update_yaxes(title=target)
    return fig
# ======================================================== #
# ==================== HTML Components =================== #
# ======================================================== #
def generate_spending_header(heading):
return st.markdown(f"""