# DSA_Project / app.py
# Last commit by saranimje: c593dca "Update app.py" (verified)
import streamlit as st
import joblib
import pandas as pd
from PIL import Image
# Load the model and image
@st.cache_resource
def load_model():
    """Return the trained churn classifier persisted in ``best_model.pkl``.

    The path is relative, so it is resolved against the current working
    directory. ``st.cache_resource`` keeps the deserialized object so that
    Streamlit reruns reuse it instead of loading the file again.

    Security note: ``joblib.load`` unpickles arbitrary objects; only ship a
    model file from a trusted source.
    """
    model = joblib.load("best_model.pkl")
    return model
@st.cache_data
def load_roc_image():
    """Load the ROC-curve figure displayed below the prediction results.

    ``Image.open`` is lazy: it reads only the header and keeps the file
    handle open until pixel data is requested. Calling ``load()`` inside a
    ``with`` block reads the pixels eagerly and closes the handle
    deterministically. The returned image is therefore self-contained and
    still carries its original ``format``.

    The relative path is resolved against the current working directory.

    Returns:
        PIL.Image.Image: the fully loaded ``roc_curve_rf_tuned.png`` image.

    Raises:
        FileNotFoundError: if the PNG is missing.
        PIL.UnidentifiedImageError: if the file cannot be decoded as an image.
    """
    with Image.open("roc_curve_rf_tuned.png") as img:
        # Pull pixel data into memory before the context manager closes the file.
        img.load()
        return img
# Load the heavy artifacts up front. If either one fails, the rest of the
# page cannot work, so report the error and halt this script run.
try:
    best_model = load_model()
    roc_img = load_roc_image()
except Exception as load_err:
    # Top-level boundary: surface any failure to the user instead of a traceback.
    st.error(f"Error loading model or image: {load_err}")
    st.stop()
# Page heading and short instructions.
st.title("Customer Churn Prediction")
st.write("Enter customer information to predict likelihood of churn")

# Two side-by-side columns for the inputs. Each widget is created directly
# on its column object, which places it in that column just as a
# `with colN:` block would.
col1, col2 = st.columns(2)

# Left column: profile, tenure and engagement.
age = col1.slider("Age", min_value=18, max_value=100, value=40)
gender = col1.selectbox("Gender", options=["Male", "Female"])
tenure = col1.slider("Tenure (months)", min_value=1, max_value=60, value=30)
usage_frequency = col1.slider("Usage Frequency", min_value=1, max_value=30, value=15)
support_calls = col1.slider("Support Calls", min_value=0, max_value=10, value=4)

# Right column: billing, recency and plan details.
payment_delay = col2.slider("Payment Delay", min_value=0, max_value=30, value=15)
last_interaction = col2.slider("Last Interaction (days ago)", min_value=1, max_value=30, value=15)
total_spend = col2.slider("Total Spend", min_value=100, max_value=1000, value=620)
subscription_type = col2.selectbox("Subscription Type", options=["Premium", "Standard", "Basic"])
contract_length = col2.selectbox("Contract Length", options=["Monthly", "Quarterly", "Annual"])
# Prediction function
def make_prediction():
    """Score the current widget values with the loaded model.

    Reads the module-level inputs (``age``, ``gender``, ``tenure``,
    ``usage_frequency``, ``support_calls``, ``payment_delay``,
    ``last_interaction``, ``total_spend``, ``subscription_type``,
    ``contract_length``) and the module-level ``best_model``. The
    categorical inputs (gender, contract length, subscription type) are
    one-hot encoded into 0/1 columns. The result is a single-row DataFrame.

    Returns:
        tuple: ``(prediction, probability)``. ``prediction`` is the model's
        predicted label for the row. ``probability`` is column 1 of
        ``predict_proba``. NOTE(review): this assumes the model's
        ``classes_`` are ordered so that column 1 is the churn class (label 1),
        which the caller checks against. Confirm against the training code.

    Raises:
        ValueError: if the model records training feature names that this
            app does not build.
    """
    input_data = {
        "Age": age,
        "Gender_Male": 1 if gender == "Male" else 0,
        "Gender_Female": 1 if gender == "Female" else 0,
        "Usage Frequency": usage_frequency,
        "Support Calls": support_calls,
        "Contract Length_Monthly": 1 if contract_length == "Monthly" else 0,
        "Contract Length_Quarterly": 1 if contract_length == "Quarterly" else 0,
        "Contract Length_Annual": 1 if contract_length == "Annual" else 0,
        "Payment Delay": payment_delay,
        "Last Interaction": last_interaction,
        "Total Spend": total_spend,
        "Tenure": tenure,
        "Subscription Type_Basic": 1 if subscription_type == "Basic" else 0,
        "Subscription Type_Premium": 1 if subscription_type == "Premium" else 0,
        "Subscription Type_Standard": 1 if subscription_type == "Standard" else 0,
    }
    input_df = pd.DataFrame([input_data])

    # scikit-learn refuses a DataFrame whose column names/order differ from
    # those seen at fit time. When the model recorded them, select and
    # reorder the columns to match. This also drops any dummy column the
    # model was not trained on (e.g. if training used drop_first).
    feature_names = getattr(best_model, "feature_names_in_", None)
    if feature_names is not None:
        expected = list(feature_names)
        missing = [name for name in expected if name not in input_df.columns]
        if missing:
            raise ValueError(
                f"Model expects features the app does not provide: {missing}"
            )
        input_df = input_df[expected]

    # Predict churn and probability
    prediction = best_model.predict(input_df)
    prediction_proba = best_model.predict_proba(input_df)[:, 1]
    return prediction[0], prediction_proba[0]
# Run the model only after the user presses the button.
if st.button("Predict Churn"):
    try:
        prediction, probability = make_prediction()

        st.header("Prediction Results")

        # Bucket the churn probability into a coarse risk label.
        if probability > 0.7:
            risk_level = "High"
        elif probability > 0.4:
            risk_level = "Medium"
        else:
            risk_level = "Low"

        # Three result tiles, written directly onto their column objects.
        col1, col2, col3 = st.columns(3)
        col1.metric("Churn Prediction", "Yes" if prediction == 1 else "No")
        col2.metric("Churn Probability", f"{probability:.2f}")
        col3.metric("Risk Level", risk_level)

        # Static evaluation figure loaded at startup.
        st.subheader("Model ROC Curve")
        st.image(roc_img, caption="ROC Curve for Random Forest Model")
    except Exception as pred_err:
        # Top-level UI boundary: show the failure instead of crashing the page.
        st.error(f"Error making prediction: {pred_err}")