File size: 3,580 Bytes
c593dca
a0a2c37
 
 
 
c593dca
 
 
 
a0a2c37
c593dca
 
 
a0a2c37
c593dca
 
 
 
 
 
a0a2c37
c593dca
 
 
a0a2c37
c593dca
 
a0a2c37
c593dca
 
 
 
 
 
a0a2c37
c593dca
 
 
 
 
 
a0a2c37
c593dca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a0a2c37
c593dca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import streamlit as st
import joblib
import pandas as pd
from PIL import Image

# Model loader
@st.cache_resource
def load_model():
    """Deserialize the trained classifier stored in ``best_model.pkl``.

    The result is cached with ``st.cache_resource`` so reruns of the script
    reuse the loaded object instead of reading the file again.
    """
    model_file = "best_model.pkl"
    # NOTE(review): joblib.load unpickles arbitrary Python objects; only
    # deploy a model file that comes from a trusted source.
    model = joblib.load(model_file)
    return model

@st.cache_data
def load_roc_image():
    """Load the ROC-curve figure displayed with the prediction results.

    ``Image.open`` is lazy: it reads only the header and keeps the file
    handle open until pixel data is needed. Opening the file in a ``with``
    block and returning ``copy()`` decodes the whole image immediately.
    As a result, a truncated or corrupt PNG raises here at load time rather
    than later when the image is rendered, and the file handle is closed
    deterministically.

    Returns:
        PIL.Image.Image: an in-memory copy of ``roc_curve_rf_tuned.png``.
    """
    with Image.open("roc_curve_rf_tuned.png") as img:
        # copy() forces a full decode. The original object must not be
        # used once the context manager closes it, so return the copy.
        return img.copy()

# Load the model and the ROC figure up front; nothing below works without them.
try:
    best_model = load_model()
    roc_img = load_roc_image()
except Exception as load_err:
    # Surface the failure in the UI and halt this script run.
    st.error(f"Error loading model or image: {load_err}")
    st.stop()

# Page header
st.title("Customer Churn Prediction")
st.write("Enter customer information to predict likelihood of churn")

# Split the ten customer inputs across two side-by-side columns.
left_col, right_col = st.columns(2)

age = left_col.slider("Age", min_value=18, max_value=100, value=40)
gender = left_col.selectbox("Gender", options=["Male", "Female"])
tenure = left_col.slider("Tenure (months)", min_value=1, max_value=60, value=30)
usage_frequency = left_col.slider(
    "Usage Frequency", min_value=1, max_value=30, value=15
)
support_calls = left_col.slider("Support Calls", min_value=0, max_value=10, value=4)

payment_delay = right_col.slider("Payment Delay", min_value=0, max_value=30, value=15)
last_interaction = right_col.slider(
    "Last Interaction (days ago)", min_value=1, max_value=30, value=15
)
total_spend = right_col.slider("Total Spend", min_value=100, max_value=1000, value=620)
subscription_type = right_col.selectbox(
    "Subscription Type", options=["Premium", "Standard", "Basic"]
)
contract_length = right_col.selectbox(
    "Contract Length", options=["Monthly", "Quarterly", "Annual"]
)

# Prediction function
def make_prediction():
    """Score the customer described by the current widget values.

    Builds a single-row DataFrame from the module-level inputs (``age``,
    ``gender``, ``tenure``, ...), one-hot encoding the categorical choices,
    and runs it through the loaded ``best_model``.

    Returns:
        A ``(prediction, probability)`` pair:

        - ``prediction`` is the model's predicted label for the row.
        - ``probability`` is column 1 of ``predict_proba``, i.e. the
          probability of ``best_model.classes_[1]`` (churn, if the classes
          are 0/1).

    Raises:
        ValueError: if the model records fit-time feature names that this
            form does not produce.
    """
    input_data = {
        "Age": age,
        "Gender_Male": 1 if gender == "Male" else 0,
        "Gender_Female": 1 if gender == "Female" else 0,
        "Usage Frequency": usage_frequency,
        "Support Calls": support_calls,
        "Contract Length_Monthly": 1 if contract_length == "Monthly" else 0,
        "Contract Length_Quarterly": 1 if contract_length == "Quarterly" else 0,
        "Contract Length_Annual": 1 if contract_length == "Annual" else 0,
        "Payment Delay": payment_delay,
        "Last Interaction": last_interaction,
        "Total Spend": total_spend,
        "Tenure": tenure,
        "Subscription Type_Basic": 1 if subscription_type == "Basic" else 0,
        "Subscription Type_Premium": 1 if subscription_type == "Premium" else 0,
        "Subscription Type_Standard": 1 if subscription_type == "Standard" else 0,
    }

    input_df = pd.DataFrame([input_data])

    # scikit-learn estimators fitted on a DataFrame store their column names
    # in ``feature_names_in_``. Since 1.2 they reject input whose columns are
    # in a different order or were unseen at fit time. The dict above is in
    # no canonical order, so reorder the frame to match the fit-time layout.
    # This also drops any dummy column the model never saw (e.g. one removed
    # by ``drop_first``). Compare with ``is not None`` because the attribute
    # is an array, whose truth value is ambiguous.
    expected_columns = getattr(best_model, "feature_names_in_", None)
    if expected_columns is not None:
        missing = [name for name in expected_columns if name not in input_df.columns]
        if missing:
            raise ValueError(
                f"Model expects features this form does not provide: {missing}"
            )
        input_df = input_df[list(expected_columns)]

    # Predict churn and probability
    prediction = best_model.predict(input_df)
    prediction_proba = best_model.predict_proba(input_df)[:, 1]

    return prediction[0], prediction_proba[0]

# Score the customer only when the user clicks the button.
if st.button("Predict Churn"):
    try:
        label, churn_proba = make_prediction()

        st.header("Prediction Results")
        pred_slot, proba_slot, risk_slot = st.columns(3)

        pred_slot.metric("Churn Prediction", "Yes" if label == 1 else "No")
        proba_slot.metric("Churn Probability", f"{churn_proba:.2f}")

        # Bucket the probability: above 0.7 is high, above 0.4 is medium,
        # anything else is low.
        if churn_proba > 0.7:
            risk = "High"
        elif churn_proba > 0.4:
            risk = "Medium"
        else:
            risk = "Low"
        risk_slot.metric("Risk Level", risk)

        # Show the model's ROC curve beneath the metrics.
        st.subheader("Model ROC Curve")
        st.image(roc_img, caption="ROC Curve for Random Forest Model")
    except Exception as err:
        st.error(f"Error making prediction: {err}")