3v324v23's picture
revise plot using plotly express
e55cbec
import streamlit as st
import pandas as pd
import pickle
import plotly.express as px
# Per-cluster text shown to the user: (cluster label, HTML list of retention
# recommendations), keyed by the id returned by the k-means model's predict().
_CLUSTER_DETAILS = {
    0: (
        "Cluster 1 : High Spent Amount, High Usage Frequency",
        '<ul><li>1. Rewards and Recognition</li><li>2. Personalized Financial Solution</li><li>3. Financial Education</li></ul>',
    ),
    1: (
        "Cluster 2 : Low Spent Amount, Low Usage Frequency",
        '<ul><li>1. Improved Credit Opportunities</li><li>2. Value Propositions</li><li>3. Fee Structure Transparency</li><li>4. Financial Planning Assistance</li></ul> ',
    ),
}

# Fallback used if the k-means model ever returns an id not listed above,
# so the page renders instead of failing on an undefined recommendation.
_UNKNOWN_CLUSTER = (
    "Unrecognized cluster",
    "<p>No recommendations are available for this cluster.</p>",
)

# Result card for customers predicted to churn: prediction, cluster and the
# cluster's recommendations.
_ATTRITED_RESULT_HTML = """
<div style="background-color:#f0f0f0; padding:10px; border-radius:10px">
<p style="font-size:16px;"><b>Customer Information:</b></p>
<div style="margin-top: 20px;">
</div>
<p>Customer is predicted to be <span style="color:{color};"><b>{pred_inf}</b></span>, and belongs to <span style="color:blue;"><b>{cluster_inf}</b></span>.</p>
<p><b>Here are some recommendations to help reduce churn among customers in corresponding clusters:</b></p>
{step}
</div>
"""

# Result card for customers predicted to stay: no cluster is computed for
# them, so only the prediction is shown.
_EXISTING_RESULT_HTML = """
<div style="background-color:#f0f0f0; padding:10px; border-radius:10px">
<p style="font-size:16px;"><b>Customer Information:</b></p>
<div style="margin-top: 20px;">
</div>
<p>Customer is predicted to be <span style="color:{color};"><b>{pred_inf}</b></span>.</p>
</div>
"""


def _load_models():
    """Load and return ``(prediction_model, cluster_model)`` from ./model/.

    NOTE(review): ``pickle.load`` can execute arbitrary code embedded in the
    file; these paths must only point at trusted, locally produced artifacts.
    """
    with open('./model/xgb_tuned.pkl', 'rb') as file_1:
        prediction_model = pickle.load(file_1)
    with open('./model/model_kmeans.pkl', 'rb') as file_2:
        cluster_model = pickle.load(file_2)
    return prediction_model, cluster_model


def predict():
    """Render the churn-prediction page and show a result when submitted.

    Collects customer biodata and behaviour through Streamlit widgets. When
    the "Predict" button is pressed, the inputs are scored by the pickled
    prediction model (label 0 is shown as "Attrited Customer", anything else
    as "Existing Customer"). Attrited customers are additionally assigned a
    cluster by the pickled k-means model and shown that cluster's retention
    recommendations; existing customers are shown only the prediction.
    """
    # biodata
    st.subheader("Customer Biodata")
    # NOTE(review): several option values below look misspelled ("Unkown",
    # "Unedicated", "Post=Graduate"). They are passed to the model verbatim,
    # so confirm they match the category strings the model was trained on
    # before correcting them.
    col1, col2, col3 = st.columns(3)
    gender = col1.radio(label="Gender", options=["M", "F"])
    marital = col2.selectbox(label="Marital Status", options=["Single", "Married", "Divorced", "Unkown"])
    relationship = col3.number_input(label="No. of Relationship", step=1, value=4)
    col1, col2, col3 = st.columns(3)
    education = col1.selectbox(label="Education Level", options=["Graduate", "Unedicated", "High School", "College", "Post=Graduate", "Doctorate"])
    income = col2.selectbox(label="Income Bracket", options=["Less than $40K", "$40K - $60K", "$60K - $80K", "$80K - $120K", "$120K+", "Unkown"])
    card = col3.radio(label="Card Type", options=["Blue", "Silver", "Gold", "Platinum"])

    # behaviour
    st.subheader("Customer Behaviour")
    col1, col2, col3 = st.columns(3)
    inactive_months = col1.slider(label="Months Inactive", max_value=6, value=2)
    last_contact = col2.slider(label="Last Contact", max_value=6, value=2)
    credit_limit = col3.number_input(label="Credit Limit", step=0.1, value=10834.0)
    col1, col2, col3 = st.columns(3)
    revolving_bal = col1.number_input(label="Revolving Balance", step=1, value=0)
    transaction_count = col2.number_input(label="Transaction Count", step=1, value=51)
    transaction_amount = col3.number_input(label="Transaction Amount", step=1, value=2249)
    col1, col2, col3 = st.columns(3)
    amount_chng = col1.slider(label="Amount Change", step=0.01, min_value=0.0, max_value=5.0, value=0.522)
    count_change = col2.slider(label="Count Change", step=0.01, min_value=0.0, max_value=5.0, value=0.594)
    util_ratio = col3.slider(label="Utilization Ratio", step=0.01, min_value=0.0, max_value=1.0, value=0.0)

    submit = st.button(label="Predict")
    if not submit:
        return

    # One-row frame; column names must match the features the models expect.
    data_inf = pd.DataFrame({
        'Total_Relationship_Count': [relationship],
        'Months_Inactive_12_mon': [inactive_months],
        'Contacts_Count_12_mon': [last_contact],
        'Credit_Limit': [credit_limit],
        'Total_Revolving_Bal': [revolving_bal],
        'Total_Amt_Chng_Q4_Q1': [amount_chng],
        'Total_Trans_Amt': [transaction_amount],
        'Total_Trans_Ct': [transaction_count],
        'Total_Ct_Chng_Q4_Q1': [count_change],
        'Avg_Utilization_Ratio': [util_ratio],
        'Gender': [gender],
        'Education_Level': [education],
        'Marital_Status': [marital],
        'Income_Category': [income],
        'Card_Category': [card],
    })

    # Load only on submit: Streamlit reruns this function on every widget
    # interaction, and unpickling both models each time is wasted work.
    prediction_model, cluster_model = _load_models()

    # Assumes predict() returns an array-like with one label per input row;
    # there is exactly one row, so take its label as a scalar.
    pred_label = prediction_model.predict(data_inf)[0]

    if pred_label != 0:
        # Existing customers are not clustered, so there is no cluster or
        # recommendation to show (previously this path raised NameError).
        st.markdown(
            _EXISTING_RESULT_HTML.format(pred_inf="Existing Customer", color="green"),
            unsafe_allow_html=True,
        )
        return

    cluster_id = cluster_model.predict(data_inf)[0]
    cluster_inf, recommendation = _CLUSTER_DETAILS.get(int(cluster_id), _UNKNOWN_CLUSTER)
    st.markdown(
        _ATTRITED_RESULT_HTML.format(
            pred_inf="Attrited Customer",
            cluster_inf=cluster_inf,
            color="red",
            step=recommendation,
        ),
        unsafe_allow_html=True,
    )
def cluster():
    """Render the churn-customer clustering scatter plot.

    Reads ./csv/BankPCA.csv, which must provide ``x`` and ``y`` coordinates
    (presumably a PCA projection — confirm) and an integer cluster ``label``
    (0 or 1). The points are plotted with plotly express, with one fixed
    colour and a descriptive legend name per cluster.
    """
    bank_df_pca = pd.read_csv('./csv/BankPCA.csv')

    colors = {0: 'navy', 1: 'teal'}
    names = {0: 'High Spent Amount (>4K), High Usage Frequency',
             1: 'Low Spent Amount (<4K), Low Usage Frequency'}
    bank_df_pca['name'] = bank_df_pca['label'].map(names)

    # px.scatter colours by the 'name' column, so the palette must be keyed by
    # those names; without this the navy/teal choice was silently ignored.
    color_map = {names[label]: colors[label] for label in names}

    fig = px.scatter(bank_df_pca, x='x', y='y', color='name', hover_name='name',
                     color_discrete_map=color_map,
                     title='Churn Customer Clustering', width=800, height=600)
    fig.update_traces(marker=dict(size=5))
    fig.update_layout(showlegend=True)
    st.plotly_chart(fig, use_container_width=True)
# Script entry point: running this module directly shows the prediction page
# only; cluster() is not invoked here.
if __name__ == "__main__":
    predict()