Spaces:
Sleeping
Sleeping
import streamlit as st | |
import joblib | |
import pandas as pd | |
from PIL import Image | |
# Load the model and image | |
@st.cache_resource
def load_model(path="best_model.pkl"):
    """Load the trained churn model from *path* (default ``best_model.pkl``).

    Streamlit re-executes this whole script on every widget interaction, so
    the loaded model is cached with ``st.cache_resource``. Without caching,
    the model file would be re-read and unpickled on every slider move.

    Security note: ``joblib.load`` unpickles the file, which can execute
    arbitrary code. Only point *path* at a trusted model file.
    """
    return joblib.load(path)
@st.cache_resource
def load_roc_image(path="roc_curve_rf_tuned.png"):
    """Load the ROC-curve image displayed next to the prediction results.

    ``Image.open`` is lazy: it keeps the underlying file open until the
    pixel data is read. Opening it in a ``with`` block and returning a
    ``copy()`` forces the pixels into memory and closes the file right away.
    The result is cached with ``st.cache_resource``, so the PNG is decoded
    once instead of on every script rerun.
    """
    with Image.open(path) as img:
        # copy() loads the pixel data; the copy stays valid after the
        # context manager closes the original image and its file handle.
        return img.copy()
# Load the model and the ROC image up front. If either fails, show the
# message in the page and halt this run, since nothing below can work
# without them.
try:
    best_model = load_model()
    roc_img = load_roc_image()
except Exception as load_err:
    st.error(f"Error loading model or image: {load_err}")
    st.stop()
# Page heading and short instructions.
st.title("Customer Churn Prediction")
st.write("Enter customer information to predict likelihood of churn")

# Customer attributes, laid out in two side-by-side input columns.
# Each widget's current value is bound to a module-level name that
# make_prediction() reads.
left_col, right_col = st.columns(2)

with left_col:
    # Positional args: label, min_value, max_value, value (slider) /
    # label, options (selectbox).
    age = st.slider("Age", 18, 100, 40)
    gender = st.selectbox("Gender", ["Male", "Female"])
    tenure = st.slider("Tenure (months)", 1, 60, 30)
    usage_frequency = st.slider("Usage Frequency", 1, 30, 15)
    support_calls = st.slider("Support Calls", 0, 10, 4)

with right_col:
    payment_delay = st.slider("Payment Delay", 0, 30, 15)
    last_interaction = st.slider("Last Interaction (days ago)", 1, 30, 15)
    total_spend = st.slider("Total Spend", 100, 1000, 620)
    subscription_type = st.selectbox("Subscription Type", ["Premium", "Standard", "Basic"])
    contract_length = st.selectbox("Contract Length", ["Monthly", "Quarterly", "Annual"])
# Prediction function
def make_prediction():
    """Predict churn for the customer described by the current widget values.

    Reads the module-level input values set by the sliders/selectboxes
    (``age``, ``gender``, ``tenure``, ...) and the module-level
    ``best_model``. Categorical inputs are one-hot encoded into 0/1
    indicator columns such as ``Gender_Male`` or ``Contract Length_Annual``.

    Returns:
        tuple: ``(prediction, probability)``. ``prediction`` is the model's
        predicted label for the single input row. ``probability`` is column 1
        of ``predict_proba``, presumably the "churn" class (confirm against
        ``best_model.classes_``).

    Raises:
        KeyError: if the model was fitted on a column this function does not
            produce.
        Exception: anything raised by the model's ``predict`` /
            ``predict_proba`` propagates to the caller.
    """
    input_data = {
        "Age": age,
        "Gender_Male": 1 if gender == "Male" else 0,
        "Gender_Female": 1 if gender == "Female" else 0,
        "Usage Frequency": usage_frequency,
        "Support Calls": support_calls,
        "Contract Length_Monthly": 1 if contract_length == "Monthly" else 0,
        "Contract Length_Quarterly": 1 if contract_length == "Quarterly" else 0,
        "Contract Length_Annual": 1 if contract_length == "Annual" else 0,
        "Payment Delay": payment_delay,
        "Last Interaction": last_interaction,
        "Total Spend": total_spend,
        "Tenure": tenure,
        "Subscription Type_Basic": 1 if subscription_type == "Basic" else 0,
        "Subscription Type_Premium": 1 if subscription_type == "Premium" else 0,
        "Subscription Type_Standard": 1 if subscription_type == "Standard" else 0,
    }
    input_df = pd.DataFrame([input_data])

    # scikit-learn estimators fitted on a DataFrame record their training
    # column order in ``feature_names_in_`` and (since 1.2) reject input whose
    # columns are ordered differently. The dict above has a hand-picked order,
    # so align it to the model's order whenever the model exposes one.
    feature_names = getattr(best_model, "feature_names_in_", None)
    if feature_names is not None:
        input_df = input_df[list(feature_names)]

    # Predict churn and probability
    prediction = best_model.predict(input_df)
    prediction_proba = best_model.predict_proba(input_df)[:, 1]
    return prediction[0], prediction_proba[0]
# Run the model only when the user presses the button; any failure is
# reported in the page instead of crashing the app.
if st.button("Predict Churn"):
    try:
        churn_label, churn_prob = make_prediction()

        st.header("Prediction Results")

        # Bucket the churn probability into a coarse risk level:
        # above 0.7 is High, above 0.4 is Medium, everything else is Low.
        if churn_prob > 0.7:
            risk_level = "High"
        elif churn_prob > 0.4:
            risk_level = "Medium"
        else:
            risk_level = "Low"

        # Three metric tiles side by side: label, probability, risk bucket.
        label_col, prob_col, risk_col = st.columns(3)
        with label_col:
            st.metric("Churn Prediction", "Yes" if churn_label == 1 else "No")
        with prob_col:
            st.metric("Churn Probability", f"{churn_prob:.2f}")
        with risk_col:
            st.metric("Risk Level", risk_level)

        # Static ROC curve of the trained model, loaded at startup.
        st.subheader("Model ROC Curve")
        st.image(roc_img, caption="ROC Curve for Random Forest Model")
    except Exception as predict_err:
        st.error(f"Error making prediction: {predict_err}")