# Hugging Face Spaces status: Configuration error
import streamlit as st
import pandas as pd
import joblib
from huggingface_hub import hf_hub_download
from modelConnector import ModelConnector

# ===========================
# LOAD MODEL & DATASET
# ===========================
st.title("📊 Is Click Predictor")


@st.cache_resource
def _load_rf_model(path):
    """Deserialize the model stored at *path*, once per distinct path.

    Streamlit re-executes this whole script on every widget interaction;
    caching keeps ``joblib.load`` from re-reading the model on each rerun.
    The cached object is shared by all sessions; below it is only used
    for ``predict``.
    """
    # NOTE(review): joblib.load unpickles the file and can execute arbitrary
    # code -- only load models from a repository you control.
    return joblib.load(path)


# Download the trained model from Hugging Face, then load it (cached per path).
model_path = hf_hub_download(repo_id="taimax13/is_click_predictor", filename="rf_model.pkl")
rf_model = _load_rf_model(model_path)
st.success("✅ Model Loaded Successfully!")
# ===========================
# LOAD DATA
# ===========================
st.sidebar.header("Dataset Selection")
# Hub-hosted equivalents of the local files below (currently unused):
# X_test_path = hf_hub_download(repo_id="taimax13/is_click_data", filename="X_test_1st(1).csv")
# y_test_path = hf_hub_download(repo_id="taimax13/is_click_data", filename="y_test_1st.csv")
# train_data_path = hf_hub_download(repo_id="taimax13/is_click_data", filename="train_dataset_full - train_dataset_full (1).csv")

# X_test holds the model's input features, so it must point at the X_test file
# (named as in the Hub download above), not at a y_test label file.
# NOTE(review): confirm the exact local file name of the X_test CSV.
X_test_path = "HuggingFaceRepo/data/X_test_1st(1).csv"
y_test_path = "HuggingFaceRepo/data/y_test_1st.csv"
train_data_path = "HuggingFaceRepo/data/train_dataset_full - train_dataset_full.csv"


@st.cache_data
def _read_csv(path, header="infer"):
    """Read the CSV at *path* once per (path, header) and reuse it across reruns.

    ``st.cache_data`` returns a copy of the cached DataFrame to each caller,
    so later in-place edits (e.g. adding a column) do not alter the cache.
    """
    return pd.read_csv(path, header=header)


# Load datasets
X_test = _read_csv(X_test_path)
# header=None: the label file is read without treating its first line as a header.
# NOTE(review): if y_test_1st.csv does have a header line, it becomes an extra
# data row here -- confirm the file format.
y_test = _read_csv(y_test_path, header=None)
train_data = _read_csv(train_data_path)
st.info(f"✅ Loaded datasets: **Train: {len(train_data)} rows**, **Test: {len(X_test)} rows**")
# Initialize Model Connector
model_connector = ModelConnector()
st.title("📊 Is Click Predictor - Train, Retrain, and Predict")
# ===========================
# CHECK MODEL STATUS
# ===========================
# Compare against None rather than relying on truthiness: if .model is a
# scikit-learn ensemble, bool(model) calls __len__, which reads the fitted
# estimators list.
# NOTE(review): assumes ModelConnector sets .model to None when no model exists.
if model_connector.model is not None:
    st.success("✅ Model Loaded Successfully!")
else:
    st.warning("❌ No model found. Please train one first.")
# NOTE(review): "Predict Click" below uses rf_model (downloaded from the Hub),
# not model_connector.model, so training/retraining here does not change the
# predictions shown -- confirm this is intended.
# ===========================
# TRAIN MODEL (button is shown whether or not a model exists)
# ===========================
if st.button("🚀 Train Model"):
    st.info("🚀 Training model...")
    message = model_connector.train_model()
    st.success(message)
# ===========================
# RETRAIN MODEL
# ===========================
if st.button("🔄 Retrain Model"):
    st.info("🔄 Retraining model with latest data...")
    message = model_connector.retrain_model()
    st.success(message)
# ===========================
# SELECT A DATA SAMPLE
# ===========================
st.sidebar.header("Select a Test Sample for Prediction")
# Labels are attached to X_test positionally, so both files must have the same
# number of rows; report a mismatch clearly instead of letting pandas raise a
# length-mismatch ValueError on the assignment below.
if len(y_test) != len(X_test):
    st.error(
        f"Label file has {len(y_test)} rows but the test set has {len(X_test)} rows; "
        "they must match to attach the actual labels."
    )
    st.stop()
# Attach the actual labels to the test set; the column is dropped again from
# the selected row so it never reaches the model.
X_test["actual_click"] = y_test.values
# Allow user to pick a row
selected_index = st.sidebar.selectbox("Choose a test sample index", X_test.index)
selected_row = X_test.loc[selected_index].drop("actual_click")  # Exclude actual label
# Display selected row as a one-row table. Taking a single row from a frame with
# mixed column dtypes yields an object-dtype Series; infer_objects() restores
# concrete dtypes where the values allow it.
st.write("### Selected Data Sample:")
st.dataframe(selected_row.to_frame().T.infer_objects())
# ===========================
# MAKE PREDICTION & EXPORT CSV
# ===========================
# NOTE(review): clicking the download button triggers a Streamlit rerun in which
# st.button("Predict Click") is False again, so the result disappears after the
# download; storing it in st.session_state would keep it visible.
if st.button("Predict Click"):
    # One-row DataFrame for the model. infer_objects() undoes the object dtype
    # that single-row selection produces when X_test mixes column dtypes.
    input_data = selected_row.to_frame().T.infer_objects()
    # Make prediction with the Hub-downloaded model
    prediction = rf_model.predict(input_data)[0]
    # Add prediction to DataFrame
    input_data["is_click_predicted"] = prediction
    # Save prediction as CSV in the server's working directory.
    # NOTE(review): this fixed file name is shared by all sessions, so concurrent
    # users overwrite each other's file; the download below does not read it.
    csv_filename = "prediction_result.csv"
    input_data.to_csv(csv_filename, index=False)
    # Display Prediction Result
    st.subheader("Prediction Result")
    if prediction == 1:
        st.success("🟢 The model predicts: **User WILL CLICK on the ad!**")
    else:
        st.warning("🔴 The model predicts: **User WILL NOT CLICK on the ad.**")
    # Provide download button for prediction result (serialized in memory)
    st.download_button(
        label="📥 Download Prediction Result",
        data=input_data.to_csv(index=False).encode("utf-8"),
        file_name=csv_filename,
        mime="text/csv",
    )
st.markdown("---")
st.info("Select a test row from the **left panel**, click **'Predict Click'**, and download the prediction result as a CSV.")