import math
import pickle
import re
from collections import OrderedDict
from functools import partial

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
import streamlit_authenticator as stauth
import yaml
from numerize.numerize import numerize
from plotly.subplots import make_subplots
from yaml import SafeLoader

from classes import class_from_dict, class_to_dict
from utilities import (
    channel_name_formating,
    format_numbers,
    initialize_data,
    load_authenticator,
    load_local_css,
    send_email,
    set_header,
)

# Label of the outcome metric shown next to spends throughout the page.
target = 'Revenue'

# Spend text accepted in the per-channel inputs: a number with an optional
# magnitude suffix, case-insensitive (e.g. "250", "1.5K", "3.2m").
# Suffix-less numbers must be accepted because `numerize` (which fills these
# inputs) emits values below 1,000 without a suffix.
_SPEND_PATTERN = re.compile(r'(\d+(?:\.\d*)?)([KMBT]?)', re.IGNORECASE)
_MAGNITUDE_MULTIPLIERS = {'': 1, 'K': 10**3, 'M': 10**6, 'B': 10**9, 'T': 10**12}

st.set_page_config(layout='wide')
load_local_css('styles.css')
set_header()

# Write every session key back to itself (auth/config and form-submitter keys
# excluded). NOTE(review): presumably the common Streamlit workaround that keeps
# widget-bound values alive across page switches -- confirm.
for k, v in st.session_state.items():
    if k not in ['logout', 'login', 'config'] and not k.startswith('FormSubmitter'):
        st.session_state[k] = v


# ======================================================== #
# ======================= Functions ====================== #
# ======================================================== #


def _percent_change(new_value, old_value):
    """Return the percentage change from ``old_value`` to ``new_value``, rounded to 2 dp.

    Returns 0.0 when ``old_value`` is 0 (percentage change is undefined) instead
    of raising ZeroDivisionError inside a widget callback.
    """
    if old_value == 0:
        return 0.0
    return round(100 * (new_value - old_value) / old_value, 2)


def optimize():
    """
    Optimize the spends of the channels selected for optimization.

    Calls ``scenario.optimize`` with the overall percent change stored under
    ``'total_spends_change'`` and writes, for every optimized channel, the new
    spend (numerized, multiplied by the channel conversion rate) and its percent
    change versus actual spends back into session state. Does nothing when no
    channel is selected.
    """
    channel_list = [
        name for name, selected in st.session_state['optimization_channels'].items() if selected
    ]
    if not channel_list:
        return

    scenario = st.session_state['scenario']
    result = scenario.optimize(st.session_state['total_spends_change'], channel_list)
    for channel_name, modified_spends in result:
        channel = scenario.channels[channel_name]
        st.session_state[channel_name] = numerize(modified_spends * channel.conversion_rate, 1)
        st.session_state[f'{channel_name}_change'] = _percent_change(
            modified_spends, channel.actual_total_spends
        )


def save_scenario(scenario_name):
    """
    Save the current scenario under the given name and persist all saved scenarios.

    Parameters
    ----------
    scenario_name
        Name of the scenario to be saved
    """
    if 'saved_scenarios' not in st.session_state:
        # Create the key; previously this rebound the whole `st.session_state`
        # object to an OrderedDict, so the next line raised KeyError.
        st.session_state['saved_scenarios'] = OrderedDict()
    st.session_state['saved_scenarios'][scenario_name] = class_to_dict(st.session_state['scenario'])
    st.session_state['scenario_input'] = ""
    # Path is relative to the process working directory.
    with open('../saved_scenarios.pkl', 'wb') as f:
        pickle.dump(st.session_state['saved_scenarios'], f)


def update_all_spends():
    """
    Apply the overall percent change in ``'total_spends_change'`` to every channel.

    Each channel's modified spend becomes ``(1 + pct/100) * actual_total_spends``;
    the scenario is updated and the per-channel input / percent widgets are
    refreshed to match.
    """
    percent_change = st.session_state['total_spends_change']
    for channel_name in st.session_state['channels_list']:
        channel = st.session_state['scenario'].channels[channel_name]
        current_spends = channel.actual_total_spends
        modified_spends = (1 + percent_change / 100) * current_spends
        st.session_state['scenario'].update(channel_name, modified_spends)
        st.session_state[channel_name] = numerize(modified_spends * channel.conversion_rate, 1)
        st.session_state[f'{channel_name}_change'] = percent_change


def extract_number_for_string(string_input):
    """
    Parse a spend string such as ``"1.5M"`` or ``"250"`` into a float.

    Recognised suffixes (case-insensitive): K = 1e3, M = 1e6, B = 1e9, T = 1e12.
    A number without a suffix is taken as-is.

    Returns
    -------
    float or None
        The parsed amount, or None when the text is not a valid spend string.
    """
    match = _SPEND_PATTERN.fullmatch(string_input.strip())
    if match is None:
        return None
    number, suffix = match.groups()
    return float(number) * _MAGNITUDE_MULTIPLIERS[suffix.upper()]


def validate_input(string_input):
    """
    Return True when ``string_input`` is a spend string that
    ``extract_number_for_string`` can parse (number + optional K/M/B/T suffix).
    """
    # The old pattern `[K|M|B]` also matched a literal '|' and rejected both
    # lowercase suffixes and suffix-less numbers produced by numerize.
    if not isinstance(string_input, str):
        return False
    return _SPEND_PATTERN.fullmatch(string_input.strip()) is not None


def update_data_by_percent(channel_name):
    """
    Recompute a channel's spend from its percent-change widget.

    The text input is refreshed with the numerized spend and the scenario is
    updated with the spend divided back by the conversion rate.
    """
    channel = st.session_state['scenario'].channels[channel_name]
    prev_spends = channel.actual_total_spends * channel.conversion_rate
    modified_spends = prev_spends * (1 + st.session_state[f'{channel_name}_change'] / 100)
    st.session_state[channel_name] = numerize(modified_spends, 1)
    st.session_state['scenario'].update(channel_name, modified_spends / channel.conversion_rate)


def update_data(channel_name):
    """
    Updates the spends for the given channel from its text input.

    Invalid text is ignored here; the page renders an 'Invalid input' message
    under the field instead.
    """
    raw_value = st.session_state[channel_name]
    if not validate_input(raw_value):
        return
    channel = st.session_state['scenario'].channels[channel_name]
    modified_spends = extract_number_for_string(raw_value)
    prev_spends = channel.actual_total_spends * channel.conversion_rate
    st.session_state[f'{channel_name}_change'] = _percent_change(modified_spends, prev_spends)
    st.session_state['scenario'].update(channel_name, modified_spends / channel.conversion_rate)


def select_channel_for_optimization(channel_name):
    """
    Marks the given channel for optimization according to its checkbox.
    """
    st.session_state['optimization_channels'][channel_name] = st.session_state[f'{channel_name}_selected']


def select_all_channels_for_optimization():
    """
    Marks all the channels for optimization according to the 'select all' checkbox.
    """
    for channel_name in st.session_state['optimization_channels'].keys():
        st.session_state[f'{channel_name}_selected'] = st.session_state['optimze_all_channels']
        st.session_state['optimization_channels'][channel_name] = st.session_state['optimze_all_channels']


def update_penalty():
    """
    Updates the penalty flag for sales calculation.
    """
    st.session_state['scenario'].update_penalty(st.session_state['apply_penalty'])


def reset_scenario():
    """
    Re-run ``initialize_data`` and clear every channel's optimization checkbox
    and percent change, plus the 'select all' checkbox.
    """
    initialize_data()
    for channel_name in st.session_state['channels_list']:
        st.session_state[f'{channel_name}_selected'] = False
        st.session_state[f'{channel_name}_change'] = 0
    st.session_state['optimze_all_channels'] = False


def format_number(num):
    """Format ``num`` with an M (2 dp) or K (0 dp) suffix; smaller values with 2 dp."""
    if num >= 1_000_000:
        return f"{num / 1_000_000:.2f}M"
    elif num >= 1_000:
        return f"{num / 1_000:.0f}K"
    else:
        return f"{num:.2f}"


def summary_plot(data, x, y, title, text_column):
    """
    Horizontal bar chart of ``x`` per ``y`` coloured by 'Channel_name', with
    ``text_column`` values printed outside the bars.
    """
    # Coerce the label column BEFORE building the figure (converting afterwards
    # had no effect on the traces) and on a copy, so the caller's frame is untouched.
    data = data.copy()
    data[text_column] = pd.to_numeric(data[text_column], errors='coerce')
    fig = px.bar(data, x=x, y=y, orientation='h', title=title, text=text_column, color='Channel_name')
    # SI-prefixed, 2-significant-digit labels and hover values.
    fig.update_traces(texttemplate='%{text:.2s}', textposition='outside', hovertemplate='%{x:.2s}')
    fig.update_layout(xaxis_title=x, yaxis_title='Channel Name', showlegend=False)
    return fig


def s_curve(x, K, b, a, x0):
    """Logistic response curve K / (1 + b * exp(-a * (x - x0)))."""
    return K / (1 + b * np.exp(-a * (x - x0)))


# NOTE(review): `st.cache` is deprecated in recent Streamlit releases; migrate to
# `st.cache_data`/`st.cache_resource` once the pinned version is confirmed.
@st.cache
def plot_response_curves():
    """
    Build a grid of response-curve subplots, one per channel in ``channels_list``.

    Each subplot shows the scaled s-curve (x and y multiplied by 52, spends by the
    channel conversion rate), a marker at the channel's actual total spends and
    dashed guide lines to that point. Relies on the module-level ``channels_list``.
    """
    cols = 4
    rcs = st.session_state['rcs']
    shapes = []
    # At least the original 6 rows; more when there are over 24 channels
    # (a fixed 6 rows used to fail for larger channel lists).
    rows = max(6, math.ceil(len(channels_list) / cols))
    fig = make_subplots(rows=rows, cols=cols, subplot_titles=channels_list)
    for i, col in enumerate(channels_list):
        x = st.session_state['actual_df'][col].values
        spends = x.sum()
        # Order-of-magnitude scaling of spends before they enter the curve.
        power = (np.ceil(np.log(x.max()) / np.log(10)) - 3)
        x = np.linspace(0, 3 * x.max(), 200)
        params = rcs[col]
        K, b, a, x0 = params['K'], params['b'], params['a'], params['x0']
        y = s_curve(x / 10**power, K, b, a, x0)
        # x[0] == 0, so ROI starts at inf/nan; silence the numpy warning only.
        with np.errstate(divide='ignore', invalid='ignore'):
            roi = y / x
        marginal_roi = a * y * (1 - y / K)

        conversion_rate = st.session_state['scenario'].channels[col].conversion_rate
        row, column = 1 + i // cols, i % cols + 1
        fig.add_trace(
            go.Scatter(
                x=52 * x * conversion_rate,
                y=52 * y,
                name=col,
                customdata=np.stack((roi, marginal_roi), axis=-1),
                # Line breaks restored as plotly `<br>` (the literal arrived split
                # across physical lines, which is not valid Python).
                hovertemplate="Spend:%{x:$.2s}<br>Sale:%{y:$.2s}<br>ROI:%{customdata[0]:.3f}<br>MROI:%{customdata[1]:.3f}",
            ),
            row=row,
            col=column,
        )

        current_x = spends * conversion_rate
        current_y = 52 * s_curve(spends / (10**power * 52), K, b, a, x0)
        fig.add_trace(
            go.Scatter(x=[current_x], y=[current_y], name=col, legendgroup=col,
                       showlegend=False, marker=dict(color=['black'])),
            row=row,
            col=column,
        )
        # Dashed horizontal and vertical guides to the actual-spends point.
        shapes.append(go.layout.Shape(type="line", x0=0, y0=current_y, x1=current_x, y1=current_y,
                                      line_width=1, line_dash="dash", line_color="black",
                                      xref=f'x{i+1}', yref=f'y{i+1}'))
        shapes.append(go.layout.Shape(type="line", x0=current_x, y0=0, x1=current_x, y1=current_y,
                                      line_width=1, line_dash="dash", line_color="black",
                                      xref=f'x{i+1}', yref=f'y{i+1}'))

    fig.update_layout(height=1500, width=1000, title_text="Response Curves", showlegend=False, shapes=shapes)
    fig.update_annotations(font_size=10)
    fig.update_xaxes(title='Spends')
    fig.update_yaxes(title=target)
    return fig


# ======================================================== #
# ==================== HTML Components =================== #
# ======================================================== #

# NOTE(review): the markdown strings on this page look like their HTML markup
# was stripped in the reviewed copy (only newlines remain) -- confirm against styles.css.
def generate_spending_header(heading):
    """Render a column heading for the spends table via raw markdown/HTML."""
    return st.markdown(f"\n\n{heading}\n\n", unsafe_allow_html=True)


# ======================================================== #
# =================== Session variables ================== #
# ======================================================== #

with open('config.yaml') as file:
    config = yaml.load(file, Loader=SafeLoader)
st.session_state['config'] = config

authenticator = stauth.Authenticate(
    config['credentials'],
    config['cookie']['name'],
    config['cookie']['key'],
    config['cookie']['expiry_days'],
    config['preauthorized'],
)
st.session_state['authenticator'] = authenticator
name, authentication_status, username = authenticator.login('Login', 'main')
auth_status = st.session_state.get('authentication_status')

if auth_status is True:
    authenticator.logout('Logout', 'main')
    if not st.session_state.get('initialized', False):
        initialize_data()

    channels_list = st.session_state['channels_list']

    # ======================================================== #
    # ========================== UI ========================== #
    # ======================================================== #

    st.header('Simulation')
    main_header = st.columns((2, 2))
    sub_header = st.columns((1, 1, 1, 1))
    _scenario = st.session_state['scenario']

    with main_header[0]:
        st.subheader('Actual')
    with main_header[-1]:
        st.subheader('Simulated')

    with sub_header[0]:
        st.metric(label='Spends', value=format_numbers(_scenario.actual_total_spends))
    with sub_header[1]:
        st.metric(label=target, value=format_numbers(float(_scenario.actual_total_sales), include_indicator=False))
    with sub_header[2]:
        st.metric(label='Spends', value=format_numbers(_scenario.modified_total_spends),
                  delta=numerize(_scenario.delta_spends, 1))
    with sub_header[3]:
        st.metric(label=target, value=format_numbers(float(_scenario.modified_total_sales), include_indicator=False),
                  delta=numerize(_scenario.delta_sales, 1))

    with st.expander("Channel Spends Simulator"):
        _columns = st.columns((2, 4, 1, 1))
        with _columns[0]:
            st.checkbox(label='Optimize all Channels', key='optimze_all_channels', value=False,
                        on_change=select_all_channels_for_optimization)
            st.number_input('Percent change of total spends', key='total_spends_change', step=1,
                            on_change=update_all_spends)
        with _columns[2]:
            st.button('Optimize', on_click=optimize)
        with _columns[3]:
            st.button('Reset', on_click=reset_scenario)

        st.markdown("\n", unsafe_allow_html=True)
        _columns = st.columns((2.5, 2, 1.5, 1.5, 1))
        with _columns[0]:
            generate_spending_header('Channel')
        with _columns[1]:
            generate_spending_header('Spends Input')
        with _columns[2]:
            generate_spending_header('Spends')
        with _columns[3]:
            generate_spending_header(target)
        with _columns[4]:
            generate_spending_header('Optimize')
        st.markdown("\n", unsafe_allow_html=True)

        # Rebuilt on every rerun: it used to be created once and appended to on
        # every rerun, growing without bound, although only the latest row per
        # channel was ever used (see drop_duplicates below).
        st.session_state['acutual_predicted'] = {
            'Channel_name': [],
            'Actual_spend': [],
            'Optimized_spend': [],
            'Delta': [],
        }

        for channel_name in channels_list:
            _channel_class = st.session_state['scenario'].channels[channel_name]
            _columns = st.columns((2.5, 1.5, 1.5, 1.5, 1))
            with _columns[0]:
                st.write(channel_name_formating(channel_name))
            with _columns[1]:
                spend_input = st.text_input(channel_name, key=channel_name, label_visibility='collapsed',
                                            on_change=partial(update_data, channel_name))
                if not validate_input(spend_input):
                    st.error('Invalid input')
                st.number_input('Percent change', key=f'{channel_name}_change', step=1,
                                on_change=partial(update_data_by_percent, channel_name))
            with _columns[2]:
                # spends, shown in converted units
                current_channel_spends = float(_channel_class.modified_total_spends * _channel_class.conversion_rate)
                actual_channel_spends = float(_channel_class.actual_total_spends * _channel_class.conversion_rate)
                spends_delta = float(_channel_class.delta_spends * _channel_class.conversion_rate)
                summary_rows = st.session_state['acutual_predicted']
                summary_rows['Channel_name'].append(channel_name)
                summary_rows['Actual_spend'].append(actual_channel_spends)
                summary_rows['Optimized_spend'].append(current_channel_spends)
                summary_rows['Delta'].append(spends_delta)
                st.metric('Spends', format_numbers(current_channel_spends),
                          delta=numerize(spends_delta, 1), label_visibility='collapsed')
            with _columns[3]:
                # sales
                current_channel_sales = float(_channel_class.modified_total_sales)
                sales_delta = float(_channel_class.delta_sales)
                st.metric(target, format_numbers(current_channel_sales, include_indicator=False),
                          delta=numerize(sales_delta, 1), label_visibility='collapsed')
            with _columns[4]:
                st.checkbox(label='select for optimization', key=f'{channel_name}_selected', value=False,
                            on_change=partial(select_channel_for_optimization, channel_name),
                            label_visibility='collapsed')
            st.markdown("\n", unsafe_allow_html=True)

    with st.expander("See Response Curves"):
        fig = plot_response_curves()
        st.plotly_chart(fig, use_container_width=True)

    _columns = st.columns(2)
    with _columns[0]:
        st.subheader('Save Scenario')
        scenario_name = st.text_input('Scenario name', key='scenario_input', placeholder='Scenario name',
                                      label_visibility='collapsed')
        # Read the name from session state at click time; the previous lambda used
        # the value captured when the button was last rendered.
        st.button('Save', on_click=lambda: save_scenario(st.session_state['scenario_input']),
                  disabled=len(st.session_state['scenario_input']) == 0)

    summary_df = pd.DataFrame(st.session_state['acutual_predicted'])
    summary_df.drop_duplicates(subset='Channel_name', keep='last', inplace=True)
    summary_df_sorted = summary_df.sort_values(by='Delta', ascending=False)
    summary_df_sorted['Delta_percent'] = np.round(
        ((summary_df_sorted['Optimized_spend'] / summary_df_sorted['Actual_spend']) - 1) * 100, 2
    )
    with open("summary_df.pkl", "wb") as f:
        pickle.dump(summary_df_sorted, f)

    # Summary charts are currently disabled:
    # ___columns = st.columns(3)
    # with ___columns[2]:
    #     fig = summary_plot(summary_df_sorted, x='Delta_percent', y='Channel_name', title='Delta', text_column='Delta_percent')
    #     st.plotly_chart(fig, use_container_width=True)
    # with ___columns[0]:
    #     fig = summary_plot(summary_df_sorted, x='Actual_spend', y='Channel_name', title='Actual Spend', text_column='Actual_spend')
    #     st.plotly_chart(fig, use_container_width=True)
    # with ___columns[1]:
    #     fig = summary_plot(summary_df_sorted, x='Optimized_spend', y='Channel_name', title='Planned Spend', text_column='Optimized_spend')
    #     st.plotly_chart(fig, use_container_width=True)

elif auth_status is False:
    st.error('Username/Password is incorrect')

if auth_status is not True:
    try:
        username_forgot_pw, email_forgot_password, random_password = authenticator.forgot_password('Forgot password')
        if username_forgot_pw:
            # NOTE(review): the new hash is only stored in session state here; nothing
            # on this page writes it back to config.yaml -- confirm persistence.
            st.session_state['config']['credentials']['usernames'][username_forgot_pw]['password'] = \
                stauth.Hasher([random_password]).generate()[0]
            # NOTE(review): the random password is passed to send_email in plain text;
            # confirm send_email transfers it securely.
            send_email(email_forgot_password, random_password)
            st.success('New password sent securely')
        elif username_forgot_pw is False:
            st.error('Username not found')
    except Exception as e:
        # UI boundary: surface any authenticator/email failure to the user.
        st.error(e)